diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 9d7bfcaa..fbeefe6a 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -5,7 +5,7 @@ from uuid import uuid4 from primaite import getLogger from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application, ApplicationOperatingState +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager _LOGGER = getLogger(__name__) @@ -23,7 +23,6 @@ class DatabaseClient(Application): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - connections: Dict[str, Dict] = {} _query_success_tracker: Dict[str, bool] = {} def __init__(self, **kwargs): @@ -143,7 +142,7 @@ class DatabaseClient(Application): dest_ip_address=self.server_ip_address, dest_port=self.port, ) - self.connections.pop(connection_id) + self.remove_connection(connection_id=connection_id) self.sys_log.info( f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}" @@ -181,8 +180,6 @@ class DatabaseClient(Application): def run(self) -> None: """Run the DatabaseClient.""" super().run() - if self.operating_state == ApplicationOperatingState.RUNNING: - self.connect() def query(self, sql: str, connection_id: Optional[str] = None) -> bool: """ @@ -221,7 +218,8 @@ class DatabaseClient(Application): if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_response": if payload["response"] is True: - self.connections[payload.get("connection_id")] = payload + # add connection + self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id) elif payload["type"] == "sql": query_id = payload.get("uuid") status_code = payload.get("status_code") diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 87959e9b..a1429e51 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -5,7 +5,6 @@ from typing import Optional from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -177,9 +176,9 @@ class DataManipulationBot(DatabaseClient): This is the core loop where the bot sequentially goes through the stages of the attack. """ - if self.operating_state != ApplicationOperatingState.RUNNING: + if not self._can_perform_action(): return - if self.server_ip_address and self.payload and self.operating_state: + if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") self._logon() self._perform_port_scan(p_of_success=self.port_scan_p_of_success) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py new file mode 100644 index 00000000..e6c643ee --- /dev/null +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -0,0 +1,184 @@ +from enum import IntEnum +from ipaddress import IPv4Address +from typing import Optional + +from primaite import getLogger +from primaite.game.science import simulate_trial +from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.database_client import DatabaseClient + +_LOGGER = getLogger(__name__) + + +class DoSAttackStage(IntEnum): + """Enum representing the different stages of a Denial of Service attack.""" + + NOT_STARTED = 0 + "Attack not yet started." + + PORT_SCAN = 1 + "Attack is in discovery stage - checking if provided ip and port are open." + + ATTACKING = 2 + "Denial of Service attack is in progress." + + COMPLETED = 3 + "Attack is completed." + + +class DoSBot(DatabaseClient, Application): + """A bot that simulates a Denial of Service attack.""" + + target_ip_address: Optional[IPv4Address] = None + """IP address of the target service.""" + + target_port: Optional[Port] = None + """Port of the target service.""" + + payload: Optional[str] = None + """Payload to deliver to the target service as part of the denial of service attack.""" + + repeat: bool = False + """If true, the Denial of Service bot will keep performing the attack.""" + + attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED + """Current stage of the DoS kill chain.""" + + port_scan_p_of_success: float = 0.1 + """Probability of port scanning being sucessful.""" + + dos_intensity: float = 0.25 + """How much of the max sessions will be used by the DoS when attacking.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.name = "DoSBot" + self.max_sessions = 1000 # override normal max sessions + + def set_original_state(self): + """Set the original state of the Denial of Service Bot.""" + _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") + super().set_original_state() + vals_to_include = { + "target_ip_address", + "target_port", + "payload", + "repeat", + "attack_stage", + "max_sessions", + "port_scan_p_of_success", + "dos_intensity", + } + self._original_state.update(self.model_dump(include=vals_to_include)) + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") + super().reset_component_for_episode(episode) + + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + + return rm + + def configure( + self, + target_ip_address: IPv4Address, + target_port: Optional[Port] = Port.POSTGRES_SERVER, + payload: Optional[str] = None, + repeat: bool = False, + max_sessions: int = 1000, + ): + """ + Configure the Denial of Service bot. + + :param: target_ip_address: The IP address of the Node containing the target service. + :param: target_port: The port of the target service. Optional - Default is `Port.HTTP` + :param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None` + :param: repeat: If True, the bot will maintain the attack. Optional - Default is `True` + :param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000 + """ + self.target_ip_address = target_ip_address + self.target_port = target_port + self.payload = payload + self.repeat = repeat + self.max_sessions = max_sessions + self.sys_log.info( + f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, {repeat=}." + ) + + def run(self): + """Run the Denial of Service Bot.""" + super().run() + self._application_loop() + + def _application_loop(self): + """ + The main application loop for the Denial of Service bot. + + The loop goes through the stages of a DoS attack. + """ + if not self._can_perform_action(): + return + + # DoS bot cannot do anything without a target + if not self.target_ip_address or not self.target_port: + self.sys_log.error( + f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}" + ) + return + + self.clear_connections() + self._perform_port_scan(p_of_success=self.port_scan_p_of_success) + self._perform_dos() + + if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING: + self.attack_stage = DoSAttackStage.NOT_STARTED + else: + self.attack_stage = DoSAttackStage.COMPLETED + + def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): + """ + Perform a simulated port scan to check for open SQL ports. + + Advances the attack stage to `PORT_SCAN` if successful. + + :param p_of_success: Probability of successful port scan, by default 0.1. + """ + if self.attack_stage == DoSAttackStage.NOT_STARTED: + # perform a port scan to identify that the SQL port is open on the server + if simulate_trial(p_of_success): + self.sys_log.info(f"{self.name}: Performing port scan") + # perform the port scan + port_is_open = True # Temporary; later we can implement NMAP port scan. + if port_is_open: + self.sys_log.info(f"{self.name}: ") + self.attack_stage = DoSAttackStage.PORT_SCAN + + def _perform_dos(self): + """ + Perform the Denial of Service attack. + + DoSBot does this by clogging up the available connections to a service. + """ + if not self.attack_stage == DoSAttackStage.PORT_SCAN: + return + self.attack_stage = DoSAttackStage.ATTACKING + self.server_ip_address = self.target_ip_address + self.port = self.target_port + + dos_sessions = int(float(self.max_sessions) * self.dos_intensity) + for i in range(dos_sessions): + self.connect() + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the bot, iterate through the application loop. + + :param timestep: The timestep value to update the bot's state. + """ + self._application_loop() diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 70a4e6cc..7d313068 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -45,7 +45,7 @@ class DatabaseService(Service): super().set_original_state() vals_to_include = { "password", - "connections", + "_connections", "backup_server", "latest_backup_directory", "latest_backup_file_name", @@ -55,7 +55,7 @@ class DatabaseService(Service): def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.connections.clear() + self.clear_connections() super().reset_component_for_episode(episode) def configure_backup(self, backup_server: IPv4Address): @@ -225,9 +225,6 @@ class DatabaseService(Service): :param session_id: The session identifier. :return: True if the Status Code is 200, otherwise False. """ - if not super().receive(payload=payload, session_id=session_id, **kwargs): - return False - result = {"status_code": 500, "data": []} # if server service is down, return error diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 52655fa4..7faa5d32 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC): RFC 959: https://datatracker.ietf.org/doc/html/rfc959 """ - connected: bool = False - """Keeps track of whether or not the FTP client is connected to an FTP server.""" - def __init__(self, **kwargs): kwargs["name"] = "FTPClient" kwargs["port"] = Port.FTP @@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC): software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port ) - if payload.status_code == FTPStatusCode.OK: - self.connected = False - return True - return False + return payload.status_code == FTPStatusCode.OK def send_file( self, @@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC): return False # check if FTP is currently connected to IP - self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) + self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) - if not self.connected: + if not len(self.connections): return False else: self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}") @@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC): :type: dest_port: Optional[Port] """ # check if FTP is currently connected to IP - self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) + self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port) - if not self.connected: + if not len(self.connections): return False else: # send retrieve request @@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC): self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}") return False + # if PORT succeeded, add the connection as an active connection list + if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK: + self.add_connection(connection_id=session_id, session_id=session_id) + + # if QUIT succeeded, remove the session from active connection list + if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK: + self.remove_connection(connection_id=session_id) + self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}") self._process_ftp_command(payload=payload, session_id=session_id) diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 6e6c1a48..585690b6 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -37,7 +37,7 @@ class FTPServer(FTPServiceABC): def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.connections.clear() + self.clear_connections() super().reset_component_for_episode(episode) def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 52187e51..3155a4bd 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,5 +1,3 @@ -import copy -from datetime import datetime from enum import Enum from typing import Any, Dict, Optional @@ -42,9 +40,6 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." - _connections: Dict[str, Dict] = {} - "Active connections to the Service." - def __init__(self, **kwargs): super().__init__(**kwargs) @@ -103,11 +98,6 @@ class Service(IOSoftware): rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) return rm - @property - def connections(self) -> Dict[str, Dict]: - """Return the public version of connections.""" - return copy.copy(self._connections) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -123,56 +113,6 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state - def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: - """ - Create a new connection to this service. - - Returns true if connection successfully created - - :param: connection_id: UUID of the connection to create - :type: string - """ - # if over or at capacity, set to overwhelmed - if len(self._connections) >= self.max_sessions: - self.health_state_actual = SoftwareHealthState.OVERWHELMED - self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") - return False - else: - # if service was previously overwhelmed, set to good because there is enough space for connections - if self.health_state_actual == SoftwareHealthState.OVERWHELMED: - self.health_state_actual = SoftwareHealthState.GOOD - - # check that connection already doesn't exist - if not self._connections.get(connection_id): - session_details = None - if session_id: - session_details = self._get_session_details(session_id) - self._connections[connection_id] = { - "ip_address": session_details.with_ip_address if session_details else None, - "time": datetime.now(), - } - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") - return True - # connection with given id already exists - self.sys_log.error( - f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." - ) - return False - - def remove_connection(self, connection_id: str) -> bool: - """ - Remove a connection from this service. - - Returns true if connection successfully removed - - :param: connection_id: UUID of the connection to create - :type: string - """ - if self._connections.get(connection_id): - self._connections.pop(connection_id) - self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") - return True - def stop(self) -> None: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 8746bdf3..b393ffd8 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,4 +1,6 @@ +import copy from abc import abstractmethod +from datetime import datetime from enum import Enum from ipaddress import IPv4Address from typing import Any, Dict, Optional @@ -206,6 +208,8 @@ class IOSoftware(Software): "Indicates if the software uses UDP protocol for communication. Default is True." port: Port "The port to which the software is connected." + _connections: Dict[str, Dict] = {} + "Active connections." def set_original_state(self): """Sets the original state.""" @@ -250,6 +254,65 @@ class IOSoftware(Software): return False return True + @property + def connections(self) -> Dict[str, Dict]: + """Return the public version of connections.""" + return copy.copy(self._connections) + + def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool: + """ + Create a new connection to this service. + + Returns true if connection successfully created + + :param: connection_id: UUID of the connection to create + :type: string + """ + # if over or at capacity, set to overwhelmed + if len(self._connections) >= self.max_sessions: + self.health_state_actual = SoftwareHealthState.OVERWHELMED + self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + return False + else: + # if service was previously overwhelmed, set to good because there is enough space for connections + if self.health_state_actual == SoftwareHealthState.OVERWHELMED: + self.health_state_actual = SoftwareHealthState.GOOD + + # check that connection already doesn't exist + if not self._connections.get(connection_id): + session_details = None + if session_id: + session_details = self._get_session_details(session_id) + self._connections[connection_id] = { + "ip_address": session_details.with_ip_address if session_details else None, + "time": datetime.now(), + } + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + return True + # connection with given id already exists + self.sys_log.error( + f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + ) + return False + + def remove_connection(self, connection_id: str) -> bool: + """ + Remove a connection from this service. + + Returns true if connection successfully removed + + :param: connection_id: UUID of the connection to create + :type: string + """ + if self.connections.get(connection_id): + self._connections.pop(connection_id) + self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.") + return True + + def clear_connections(self): + """Clears all the connections from the software.""" + self._connections = {} + def send( self, payload: Any, diff --git a/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py new file mode 100644 index 00000000..2828cc25 --- /dev/null +++ b/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py @@ -0,0 +1,107 @@ +from ipaddress import IPv4Address +from typing import Tuple + +import pytest + +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.software import SoftwareHealthState + + +@pytest.fixture(scope="function") +def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]: + computer, server = client_server + + # Install DoSBot on computer + computer.software_manager.install(DoSBot) + + dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot.configure( + target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address), + target_port=Port.POSTGRES_SERVER, + ) + + # Install FTP Server service on server + server.software_manager.install(DatabaseService) + db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_server_service.start() + + return dos_bot, computer, db_server_service, server + + +def test_repeating_dos_attack(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + dos_bot.port_scan_p_of_success = 1 + dos_bot.repeat = True + dos_bot.run() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + db_server_service.clear_connections() + db_server_service.health_state_actual = SoftwareHealthState.GOOD + assert len(db_server_service.connections) == 0 + + computer.apply_timestep(timestep=1) + server.apply_timestep(timestep=1) + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + +def test_non_repeating_dos_attack(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + dos_bot.port_scan_p_of_success = 1 + dos_bot.repeat = False + dos_bot.run() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions + + assert dos_bot.attack_stage is DoSAttackStage.COMPLETED + assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED + + db_server_service.clear_connections() + db_server_service.health_state_actual = SoftwareHealthState.GOOD + assert len(db_server_service.connections) == 0 + + computer.apply_timestep(timestep=1) + server.apply_timestep(timestep=1) + + assert len(dos_bot.connections) == 0 + assert len(db_server_service.connections) == 0 + assert len(dos_bot.connections) == 0 + + assert dos_bot.attack_stage is DoSAttackStage.COMPLETED + assert db_server_service.health_state_actual is SoftwareHealthState.GOOD + + +def test_dos_bot_database_service_connection(dos_bot_and_db_server): + dos_bot, computer, db_server_service, server = dos_bot_and_db_server + + dos_bot.operating_state = ApplicationOperatingState.RUNNING + dos_bot.attack_stage = DoSAttackStage.PORT_SCAN + dos_bot._perform_dos() + + assert len(dos_bot.connections) == db_server_service.max_sessions + assert len(db_server_service.connections) == db_server_service.max_sessions + assert len(dos_bot.connections) == db_server_service.max_sessions diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py new file mode 100644 index 00000000..71489171 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -0,0 +1,90 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot + + +@pytest.fixture(scope="function") +def dos_bot() -> DoSBot: + computer = Computer( + hostname="compromised_pc", + ip_address="192.168.0.1", + subnet_mask="255.255.255.0", + operating_state=NodeOperatingState.ON, + ) + + computer.software_manager.install(DoSBot) + + dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) + dos_bot.set_original_state() + return dos_bot + + +def test_dos_bot_creation(dos_bot): + """Test that the DoS bot is installed on a node.""" + assert dos_bot is not None + + +def test_dos_bot_reset(dos_bot): + assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") + assert dos_bot.target_port is Port.POSTGRES_SERVER + assert dos_bot.payload is None + assert dos_bot.repeat is False + + dos_bot.configure( + target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True + ) + + # should reset the relevant items + dos_bot.reset_component_for_episode(episode=0) + assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") + assert dos_bot.target_port is Port.POSTGRES_SERVER + assert dos_bot.payload is None + assert dos_bot.repeat is False + + dos_bot.configure( + target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True + ) + dos_bot.set_original_state() + dos_bot.reset_component_for_episode(episode=1) + # should reset to the configured value + assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") + assert dos_bot.target_port is Port.HTTP + assert dos_bot.payload == "payload" + assert dos_bot.repeat is True + + +def test_dos_bot_cannot_run_when_node_offline(dos_bot): + dos_bot_node: Computer = dos_bot.parent + assert dos_bot_node.operating_state is NodeOperatingState.ON + + dos_bot_node.power_off() + + for i in range(dos_bot_node.shut_down_duration + 1): + dos_bot_node.apply_timestep(timestep=i) + + assert dos_bot_node.operating_state is NodeOperatingState.OFF + + dos_bot._application_loop() + + # assert not run + assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED + + +def test_dos_bot_not_configured(dos_bot): + dos_bot.target_ip_address = None + + dos_bot.operating_state = ApplicationOperatingState.RUNNING + dos_bot._application_loop() + + +def test_dos_bot_perform_port_scan(dos_bot): + dos_bot._perform_port_scan(p_of_success=1) + + assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py index 15d28d4b..204b356f 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -63,10 +63,11 @@ def test_disconnect_when_client_is_closed(database_client_on_computer): def test_disconnect(database_client_on_computer): - """Database client should set connected to False and remove the database server ip address.""" + """Database client should remove the connection.""" database_client, computer = database_client_on_computer - database_client.connections[uuid4()] = {} + database_client._connections[str(uuid4())] = {"item": True} + assert len(database_client.connections) == 1 assert database_client.operating_state is ApplicationOperatingState.RUNNING assert database_client.server_ip_address is not None @@ -75,6 +76,14 @@ def test_disconnect(database_client_on_computer): assert len(database_client.connections) == 0 + uuid = str(uuid4()) + database_client._connections[uuid] = {"item": True} + assert len(database_client.connections) == 1 + + database_client.disconnect(connection_id=uuid) + + assert len(database_client.connections) == 0 + def test_query_when_client_is_closed(database_client_on_computer): """Database client should return False when it is not running.""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py index b32463a2..016cf011 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_services.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.software import SoftwareHealthState @@ -66,3 +68,34 @@ def test_enable_disable(service): service.enable() assert service.operating_state == ServiceOperatingState.STOPPED + + +def test_overwhelm_service(service): + service.max_sessions = 2 + service.start() + + uuid = str(uuid4()) + assert service.add_connection(connection_id=uuid) # should be true + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert not service.add_connection(connection_id=uuid) # fails because connection already exists + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert service.add_connection(connection_id=str(uuid4())) # succeed + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity + assert service.health_state_actual is SoftwareHealthState.OVERWHELMED + + +def test_create_and_remove_connections(service): + service.start() + uuid = str(uuid4()) + + assert service.add_connection(connection_id=uuid) # should be true + assert len(service.connections) == 1 + assert service.health_state_actual is SoftwareHealthState.GOOD + + assert service.remove_connection(connection_id=uuid) # should be true + assert len(service.connections) == 0 + assert service.health_state_actual is SoftwareHealthState.GOOD