#2059: moved connection handling from Service to IOSoftware + changes that now utilise connections from IOSoftware + dos bot attacking now works + tests
This commit is contained in:
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -23,7 +23,6 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connections: Dict[str, Dict] = {}
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -143,7 +142,7 @@ class DatabaseClient(Application):
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.connections.pop(connection_id)
|
||||
self.remove_connection(connection_id=connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
@@ -181,8 +180,6 @@ class DatabaseClient(Application):
|
||||
def run(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
super().run()
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
@@ -221,7 +218,8 @@ class DatabaseClient(Application):
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
if payload["response"] is True:
|
||||
self.connections[payload.get("connection_id")] = payload
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
elif payload["type"] == "sql":
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Optional
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -177,9 +176,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if self.operating_state != ApplicationOperatingState.RUNNING:
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
if self.server_ip_address and self.payload and self.operating_state:
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DoSAttackStage(IntEnum):
|
||||
"""Enum representing the different stages of a Denial of Service attack."""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Attack not yet started."
|
||||
|
||||
PORT_SCAN = 1
|
||||
"Attack is in discovery stage - checking if provided ip and port are open."
|
||||
|
||||
ATTACKING = 2
|
||||
"Denial of Service attack is in progress."
|
||||
|
||||
COMPLETED = 3
|
||||
"Attack is completed."
|
||||
|
||||
|
||||
class DoSBot(DatabaseClient, Application):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
target_port: Optional[Port] = None
|
||||
"""Port of the target service."""
|
||||
|
||||
payload: Optional[str] = None
|
||||
"""Payload to deliver to the target service as part of the denial of service attack."""
|
||||
|
||||
repeat: bool = False
|
||||
"""If true, the Denial of Service bot will keep performing the attack."""
|
||||
|
||||
attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED
|
||||
"""Current stage of the DoS kill chain."""
|
||||
|
||||
port_scan_p_of_success: float = 0.1
|
||||
"""Probability of port scanning being sucessful."""
|
||||
|
||||
dos_intensity: float = 0.25
|
||||
"""How much of the max sessions will be used by the DoS when attacking."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state of the Denial of Service Bot."""
|
||||
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"target_ip_address",
|
||||
"target_port",
|
||||
"payload",
|
||||
"repeat",
|
||||
"attack_stage",
|
||||
"max_sessions",
|
||||
"port_scan_p_of_success",
|
||||
"dos_intensity",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
|
||||
return rm
|
||||
|
||||
def configure(
|
||||
self,
|
||||
target_ip_address: IPv4Address,
|
||||
target_port: Optional[Port] = Port.POSTGRES_SERVER,
|
||||
payload: Optional[str] = None,
|
||||
repeat: bool = False,
|
||||
max_sessions: int = 1000,
|
||||
):
|
||||
"""
|
||||
Configure the Denial of Service bot.
|
||||
|
||||
:param: target_ip_address: The IP address of the Node containing the target service.
|
||||
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
|
||||
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
|
||||
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
|
||||
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
|
||||
"""
|
||||
self.target_ip_address = target_ip_address
|
||||
self.target_port = target_port
|
||||
self.payload = payload
|
||||
self.repeat = repeat
|
||||
self.max_sessions = max_sessions
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, {repeat=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
self._perform_dos()
|
||||
|
||||
if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING:
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PORT_SCAN` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == DoSAttackStage.NOT_STARTED:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing port scan")
|
||||
# perform the port scan
|
||||
port_is_open = True # Temporary; later we can implement NMAP port scan.
|
||||
if port_is_open:
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
|
||||
def _perform_dos(self):
|
||||
"""
|
||||
Perform the Denial of Service attack.
|
||||
|
||||
DoSBot does this by clogging up the available connections to a service.
|
||||
"""
|
||||
if not self.attack_stage == DoSAttackStage.PORT_SCAN:
|
||||
return
|
||||
self.attack_stage = DoSAttackStage.ATTACKING
|
||||
self.server_ip_address = self.target_ip_address
|
||||
self.port = self.target_port
|
||||
|
||||
dos_sessions = int(float(self.max_sessions) * self.dos_intensity)
|
||||
for i in range(dos_sessions):
|
||||
self.connect()
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, iterate through the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
@@ -45,7 +45,7 @@ class DatabaseService(Service):
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"connections",
|
||||
"_connections",
|
||||
"backup_server",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
@@ -55,7 +55,7 @@ class DatabaseService(Service):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
@@ -225,9 +225,6 @@ class DatabaseService(Service):
|
||||
:param session_id: The session identifier.
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
return False
|
||||
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
|
||||
@@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC):
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
connected: bool = False
|
||||
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC):
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
|
||||
)
|
||||
if payload.status_code == FTPStatusCode.OK:
|
||||
self.connected = False
|
||||
return True
|
||||
return False
|
||||
return payload.status_code == FTPStatusCode.OK
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
@@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC):
|
||||
return False
|
||||
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
|
||||
@@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC):
|
||||
:type: dest_port: Optional[Port]
|
||||
"""
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
# send retrieve request
|
||||
@@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC):
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
|
||||
return False
|
||||
|
||||
# if PORT succeeded, add the connection as an active connection list
|
||||
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
|
||||
@@ -37,7 +37,7 @@ class FTPServer(FTPServiceABC):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -42,9 +40,6 @@ class Service(IOSoftware):
|
||||
restart_countdown: Optional[int] = None
|
||||
"If currently restarting, how many timesteps remain until the restart is finished."
|
||||
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections to the Service."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -103,11 +98,6 @@ class Service(IOSoftware):
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
return rm
|
||||
|
||||
@property
|
||||
def connections(self) -> Dict[str, Dict]:
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -123,56 +113,6 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
Returns true if connection successfully created
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
session_details = None
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
if self._connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
@@ -206,6 +208,8 @@ class IOSoftware(Software):
|
||||
"Indicates if the software uses UDP protocol for communication. Default is True."
|
||||
port: Port
|
||||
"The port to which the software is connected."
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
@@ -250,6 +254,65 @@ class IOSoftware(Software):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def connections(self) -> Dict[str, Dict]:
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
Returns true if connection successfully created
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
session_details = None
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
if self.connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
|
||||
def clear_connections(self):
|
||||
"""Clears all the connections from the software."""
|
||||
self._connections = {}
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
|
||||
Reference in New Issue
Block a user