Merged PR 234: #2059: Denial of Service Bot
## Summary - Moved DataManipulationBot into red applications - these are applications, not services - moved the connection handling from Service base class to the IOSoftware base class - Applications and Services can track the connections they make - increased default max sessions to 100 - made sure the services/applications that are dependent on connections use the IOSoftware connections - DoSBot follows some sort of kill chain, although at the moment it just: - runs a port scan with a 10% success chance (by default) - runs a DoS attack that fills a service's max sessions ## Test process unit test in https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/234?path=/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py&_a=files integration test in https://dev.azure.com/ma-dev-uk/PrimAITE/_git/PrimAITE/pullrequest/234?path=/tests/integration_tests/system/red_applications/test_dos_bot_and_server.py&_a=files ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2059
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
|
||||
@@ -20,13 +20,13 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -314,6 +314,7 @@ class PrimaiteGame:
|
||||
opt = application_cfg["options"]
|
||||
new_application.configure(
|
||||
server_ip_address=IPv4Address(opt.get("server_ip")),
|
||||
server_password=opt.get("server_password"),
|
||||
payload=opt.get("payload"),
|
||||
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
|
||||
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
|
||||
|
||||
@@ -9,10 +9,10 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
@@ -252,9 +252,9 @@ def arcd_uc2_network() -> Network:
|
||||
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
|
||||
database_service.start()
|
||||
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
|
||||
database_service._process_sql(ddl, None) # noqa
|
||||
database_service._process_sql(ddl, None, None) # noqa
|
||||
for insert_statement in user_insert_statements:
|
||||
database_service._process_sql(insert_statement, None) # noqa
|
||||
database_service._process_sql(insert_statement, None, None) # noqa
|
||||
|
||||
# Web Server
|
||||
web_server = Server(
|
||||
|
||||
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -23,7 +23,6 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -66,18 +65,24 @@ class DatabaseClient(Application):
|
||||
self.server_password = server_password
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
def connect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.connected:
|
||||
return self._connect(self.server_ip_address, self.server_password)
|
||||
# already connected
|
||||
return True
|
||||
if not connection_id:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id
|
||||
)
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
@@ -92,33 +97,58 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self.connected:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
|
||||
if self.connections.get(connection_id):
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return self.connected
|
||||
return True
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
)
|
||||
return False
|
||||
payload = {"type": "connect_request", "password": password}
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(server_ip_address, password, True)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
)
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self, connection_id: Optional[str] = None) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
)
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
|
||||
self.server_ip_address = None
|
||||
self.connected = False
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not len(self.connections):
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
return False
|
||||
|
||||
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
|
||||
# if no connection provided, disconnect the first connection
|
||||
if not connection_id:
|
||||
connection_id = list(self.connections.keys())[0]
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected connection {connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -141,19 +171,17 @@ class DatabaseClient(Application):
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id},
|
||||
payload={"type": "sql", "sql": sql, "uuid": query_id, "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
|
||||
return self._query(sql=sql, query_id=query_id, connection_id=connection_id, is_reattempt=True)
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the DatabaseClient."""
|
||||
super().run()
|
||||
if self.operating_state == ApplicationOperatingState.RUNNING:
|
||||
self.connect()
|
||||
|
||||
def query(self, sql: str, is_reattempt: bool = False) -> bool:
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
@@ -164,20 +192,17 @@ class DatabaseClient(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.connected:
|
||||
query_id = str(uuid4())
|
||||
if connection_id is None:
|
||||
connection_id = str(uuid4())
|
||||
|
||||
if not self.connections.get(connection_id):
|
||||
if not self.connect(connection_id=connection_id):
|
||||
return False
|
||||
|
||||
# Initialise the tracker of this ID to False
|
||||
self._query_success_tracker[query_id] = False
|
||||
return self._query(sql=sql, query_id=query_id)
|
||||
else:
|
||||
if is_reattempt:
|
||||
return False
|
||||
|
||||
if not self.connect():
|
||||
return False
|
||||
|
||||
self.query(sql=sql, is_reattempt=True)
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
@@ -192,13 +217,13 @@ class DatabaseClient(Application):
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
if payload["response"] is True:
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
elif payload["type"] == "sql":
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
else:
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Optional
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -149,9 +148,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
self.connect()
|
||||
if self.connected:
|
||||
if len(self.connections):
|
||||
self.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
@@ -177,9 +176,9 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if self.operating_state != ApplicationOperatingState.RUNNING:
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
if self.server_ip_address and self.payload and self.operating_state:
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
@@ -0,0 +1,192 @@
|
||||
from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class DoSAttackStage(IntEnum):
|
||||
"""Enum representing the different stages of a Denial of Service attack."""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Attack not yet started."
|
||||
|
||||
PORT_SCAN = 1
|
||||
"Attack is in discovery stage - checking if provided ip and port are open."
|
||||
|
||||
ATTACKING = 2
|
||||
"Denial of Service attack is in progress."
|
||||
|
||||
COMPLETED = 3
|
||||
"Attack is completed."
|
||||
|
||||
|
||||
class DoSBot(DatabaseClient, Application):
|
||||
"""A bot that simulates a Denial of Service attack."""
|
||||
|
||||
target_ip_address: Optional[IPv4Address] = None
|
||||
"""IP address of the target service."""
|
||||
|
||||
target_port: Optional[Port] = None
|
||||
"""Port of the target service."""
|
||||
|
||||
payload: Optional[str] = None
|
||||
"""Payload to deliver to the target service as part of the denial of service attack."""
|
||||
|
||||
repeat: bool = False
|
||||
"""If true, the Denial of Service bot will keep performing the attack."""
|
||||
|
||||
attack_stage: DoSAttackStage = DoSAttackStage.NOT_STARTED
|
||||
"""Current stage of the DoS kill chain."""
|
||||
|
||||
port_scan_p_of_success: float = 0.1
|
||||
"""Probability of port scanning being sucessful."""
|
||||
|
||||
dos_intensity: float = 1.0
|
||||
"""How much of the max sessions will be used by the DoS when attacking."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DoSBot"
|
||||
self.max_sessions = 1000 # override normal max sessions
|
||||
|
||||
def set_original_state(self):
|
||||
"""Set the original state of the Denial of Service Bot."""
|
||||
_LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}")
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"target_ip_address",
|
||||
"target_port",
|
||||
"payload",
|
||||
"repeat",
|
||||
"attack_stage",
|
||||
"max_sessions",
|
||||
"port_scan_p_of_success",
|
||||
"dos_intensity",
|
||||
}
|
||||
self._original_state.update(self.model_dump(include=vals_to_include))
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}")
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
|
||||
return rm
|
||||
|
||||
def configure(
|
||||
self,
|
||||
target_ip_address: IPv4Address,
|
||||
target_port: Optional[Port] = Port.POSTGRES_SERVER,
|
||||
payload: Optional[str] = None,
|
||||
repeat: bool = False,
|
||||
port_scan_p_of_success: float = 0.1,
|
||||
dos_intensity: float = 1.0,
|
||||
max_sessions: int = 1000,
|
||||
):
|
||||
"""
|
||||
Configure the Denial of Service bot.
|
||||
|
||||
:param: target_ip_address: The IP address of the Node containing the target service.
|
||||
:param: target_port: The port of the target service. Optional - Default is `Port.HTTP`
|
||||
:param: payload: The payload the DoS Bot will throw at the target service. Optional - Default is `None`
|
||||
:param: repeat: If True, the bot will maintain the attack. Optional - Default is `True`
|
||||
:param: port_scan_p_of_success: The chance of the port scan being sucessful. Optional - Default is 0.1 (10%)
|
||||
:param: dos_intensity: The intensity of the DoS attack.
|
||||
Multiplied with the application's max session - Default is 1.0
|
||||
:param: max_sessions: The maximum number of sessions the DoS bot will attack with. Optional - Default is 1000
|
||||
"""
|
||||
self.target_ip_address = target_ip_address
|
||||
self.target_port = target_port
|
||||
self.payload = payload
|
||||
self.repeat = repeat
|
||||
self.port_scan_p_of_success = port_scan_p_of_success
|
||||
self.dos_intensity = dos_intensity
|
||||
self.max_sessions = max_sessions
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Configured the {self.name} with {target_ip_address=}, {target_port=}, {payload=}, "
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
self._perform_dos()
|
||||
|
||||
if self.repeat and self.attack_stage is DoSAttackStage.ATTACKING:
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
Perform a simulated port scan to check for open SQL ports.
|
||||
|
||||
Advances the attack stage to `PORT_SCAN` if successful.
|
||||
|
||||
:param p_of_success: Probability of successful port scan, by default 0.1.
|
||||
"""
|
||||
if self.attack_stage == DoSAttackStage.NOT_STARTED:
|
||||
# perform a port scan to identify that the SQL port is open on the server
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing port scan")
|
||||
# perform the port scan
|
||||
port_is_open = True # Temporary; later we can implement NMAP port scan.
|
||||
if port_is_open:
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
|
||||
def _perform_dos(self):
|
||||
"""
|
||||
Perform the Denial of Service attack.
|
||||
|
||||
DoSBot does this by clogging up the available connections to a service.
|
||||
"""
|
||||
if not self.attack_stage == DoSAttackStage.PORT_SCAN:
|
||||
return
|
||||
self.attack_stage = DoSAttackStage.ATTACKING
|
||||
self.server_ip_address = self.target_ip_address
|
||||
self.port = self.target_port
|
||||
|
||||
dos_sessions = int(float(self.max_sessions) * self.dos_intensity)
|
||||
for i in range(dos_sessions):
|
||||
self.connect()
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, iterate through the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
self._application_loop()
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
@@ -22,7 +21,6 @@ class DatabaseService(Service):
|
||||
"""
|
||||
|
||||
password: Optional[str] = None
|
||||
connections: Dict[str, datetime] = {}
|
||||
|
||||
backup_server: IPv4Address = None
|
||||
"""IP address of the backup server."""
|
||||
@@ -47,7 +45,7 @@ class DatabaseService(Service):
|
||||
super().set_original_state()
|
||||
vals_to_include = {
|
||||
"password",
|
||||
"connections",
|
||||
"_connections",
|
||||
"backup_server",
|
||||
"latest_backup_directory",
|
||||
"latest_backup_file_name",
|
||||
@@ -57,7 +55,7 @@ class DatabaseService(Service):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def configure_backup(self, backup_server: IPv4Address):
|
||||
@@ -140,24 +138,39 @@ class DatabaseService(Service):
|
||||
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
|
||||
|
||||
def _process_connect(
|
||||
self, session_id: str, password: Optional[str] = None
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
status_code = 500 # Default internal server error
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
|
||||
)
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
self.connections[session_id] = datetime.now()
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
status_code = 500
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
|
||||
return {
|
||||
"status_code": status_code,
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
|
||||
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
|
||||
def _process_sql(
|
||||
self, query: Literal["SELECT", "DELETE"], query_id: str, connection_id: Optional[str] = None
|
||||
) -> Dict[str, Union[int, List[Any]]]:
|
||||
"""
|
||||
Executes the given SQL query and returns the result.
|
||||
|
||||
@@ -169,15 +182,28 @@ class DatabaseService(Service):
|
||||
:return: Dictionary containing status code and data fetched.
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if query == "SELECT":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": True,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
elif query == "DELETE":
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
self.health_state_actual = SoftwareHealthState.COMPROMISED
|
||||
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
|
||||
return {
|
||||
"status_code": 200,
|
||||
"type": "sql",
|
||||
"data": False,
|
||||
"uuid": query_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
else:
|
||||
return {"status_code": 404, "data": False}
|
||||
else:
|
||||
@@ -203,19 +229,25 @@ class DatabaseService(Service):
|
||||
:param session_id: The session identifier.
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
result = {"status_code": 500, "data": []}
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
result = self._process_connect(session_id=session_id, password=payload.get("password"))
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if session_id in self.connections:
|
||||
self.connections.pop(session_id)
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
elif payload["type"] == "sql":
|
||||
if session_id in self.connections:
|
||||
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
query=payload["sql"], query_id=payload["uuid"], connection_id=payload["connection_id"]
|
||||
)
|
||||
else:
|
||||
result = {"status_code": 401, "type": "sql"}
|
||||
self.send(payload=result, session_id=session_id)
|
||||
|
||||
@@ -20,9 +20,6 @@ class FTPClient(FTPServiceABC):
|
||||
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
|
||||
"""
|
||||
|
||||
connected: bool = False
|
||||
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPClient"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -129,10 +126,7 @@ class FTPClient(FTPServiceABC):
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
|
||||
)
|
||||
if payload.status_code == FTPStatusCode.OK:
|
||||
self.connected = False
|
||||
return True
|
||||
return False
|
||||
return payload.status_code == FTPStatusCode.OK
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
@@ -179,9 +173,9 @@ class FTPClient(FTPServiceABC):
|
||||
return False
|
||||
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
|
||||
@@ -230,9 +224,9 @@ class FTPClient(FTPServiceABC):
|
||||
:type: dest_port: Optional[Port]
|
||||
"""
|
||||
# check if FTP is currently connected to IP
|
||||
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
|
||||
|
||||
if not self.connected:
|
||||
if not len(self.connections):
|
||||
return False
|
||||
else:
|
||||
# send retrieve request
|
||||
@@ -286,6 +280,14 @@ class FTPClient(FTPServiceABC):
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
|
||||
return False
|
||||
|
||||
# if PORT succeeded, add the connection as an active connection list
|
||||
if payload.ftp_command is FTPCommand.PORT and payload.status_code is FTPStatusCode.OK:
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
self._process_ftp_command(payload=payload, session_id=session_id)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
@@ -21,9 +20,6 @@ class FTPServer(FTPServiceABC):
|
||||
server_password: Optional[str] = None
|
||||
"""Password needed to connect to FTP server. Default is None."""
|
||||
|
||||
connections: Dict[str, IPv4Address] = {}
|
||||
"""Current active connections to the FTP server."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "FTPServer"
|
||||
kwargs["port"] = Port.FTP
|
||||
@@ -41,7 +37,7 @@ class FTPServer(FTPServiceABC):
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""Reset the original state of the SimComponent."""
|
||||
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
|
||||
self.connections.clear()
|
||||
self.clear_connections()
|
||||
super().reset_component_for_episode(episode)
|
||||
|
||||
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
|
||||
@@ -62,9 +58,6 @@ class FTPServer(FTPServiceABC):
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
|
||||
if payload.ftp_command is not None:
|
||||
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
|
||||
|
||||
@@ -73,7 +66,7 @@ class FTPServer(FTPServiceABC):
|
||||
# check that the port is valid
|
||||
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
|
||||
# return successful connection
|
||||
self.connections[session_id] = session_details.with_ip_address
|
||||
self.add_connection(connection_id=session_id, session_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
@@ -81,7 +74,7 @@ class FTPServer(FTPServiceABC):
|
||||
return payload
|
||||
|
||||
if payload.ftp_command == FTPCommand.QUIT:
|
||||
self.connections.pop(session_id)
|
||||
self.remove_connection(connection_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
|
||||
@@ -40,6 +40,12 @@ class Service(IOSoftware):
|
||||
restart_countdown: Optional[int] = None
|
||||
"If currently restarting, how many timesteps remain until the restart is finished."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
@@ -52,7 +58,7 @@ class Service(IOSoftware):
|
||||
if not super()._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.operating_state is not self.operating_state.RUNNING:
|
||||
if self.operating_state is not ServiceOperatingState.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
@@ -74,12 +80,6 @@ class Service(IOSoftware):
|
||||
"""
|
||||
return super().receive(payload=payload, session_id=session_id, **kwargs)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.health_state_visible = SoftwareHealthState.UNUSED
|
||||
self.health_state_actual = SoftwareHealthState.UNUSED
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
super().set_original_state()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
@@ -198,7 +200,7 @@ class IOSoftware(Software):
|
||||
|
||||
installing_count: int = 0
|
||||
"The number of times the software has been installed. Default is 0."
|
||||
max_sessions: int = 1
|
||||
max_sessions: int = 100
|
||||
"The maximum number of sessions that the software can handle simultaneously. Default is 0."
|
||||
tcp: bool = True
|
||||
"Indicates if the software uses TCP protocol for communication. Default is True."
|
||||
@@ -206,6 +208,8 @@ class IOSoftware(Software):
|
||||
"Indicates if the software uses UDP protocol for communication. Default is True."
|
||||
port: Port
|
||||
"The port to which the software is connected."
|
||||
_connections: Dict[str, Dict] = {}
|
||||
"Active connections."
|
||||
|
||||
def set_original_state(self):
|
||||
"""Sets the original state."""
|
||||
@@ -250,6 +254,65 @@ class IOSoftware(Software):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def connections(self) -> Dict[str, Dict]:
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
Returns true if connection successfully created
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.health_state_actual = SoftwareHealthState.OVERWHELMED
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
|
||||
# check that connection already doesn't exist
|
||||
if not self._connections.get(connection_id):
|
||||
session_details = None
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:type: string
|
||||
"""
|
||||
if self.connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
|
||||
def clear_connections(self):
|
||||
"""Clears all the connections from the software."""
|
||||
self._connections = {}
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: Any,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
|
||||
def test_data_manipulation(uc2_network):
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot_and_db_server(client_server) -> Tuple[DoSBot, Computer, DatabaseService, Server]:
|
||||
computer, server = client_server
|
||||
|
||||
# Install DoSBot on computer
|
||||
computer.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
|
||||
target_port=Port.POSTGRES_SERVER,
|
||||
)
|
||||
|
||||
# Install DB Server service on server
|
||||
server.software_manager.install(DatabaseService)
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
db_server_service.start()
|
||||
|
||||
return dos_bot, computer, db_server_service, server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot_db_server_green_client(example_network) -> Network:
|
||||
network: Network = example_network
|
||||
|
||||
router_1: Router = example_network.get_node_by_hostname("router_1")
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
|
||||
)
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
|
||||
# install DoS bot on client 1
|
||||
client_1.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address),
|
||||
target_port=Port.POSTGRES_SERVER,
|
||||
)
|
||||
|
||||
# install db server service on server
|
||||
server.software_manager.install(DatabaseService)
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
db_server_service.start()
|
||||
|
||||
# Install DB client (green) on client 2
|
||||
client_2.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
|
||||
database_client.run()
|
||||
|
||||
return network
|
||||
|
||||
|
||||
def test_repeating_dos_attack(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = True
|
||||
dos_bot.run()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.health_state_actual = SoftwareHealthState.GOOD
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
|
||||
def test_non_repeating_dos_attack(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = False
|
||||
dos_bot.run()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
db_server_service.clear_connections()
|
||||
db_server_service.health_state_actual = SoftwareHealthState.GOOD
|
||||
assert len(db_server_service.connections) == 0
|
||||
|
||||
computer.apply_timestep(timestep=1)
|
||||
server.apply_timestep(timestep=1)
|
||||
|
||||
assert len(dos_bot.connections) == 0
|
||||
assert len(db_server_service.connections) == 0
|
||||
assert len(dos_bot.connections) == 0
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_dos_bot_database_service_connection(dos_bot_and_db_server):
|
||||
dos_bot, computer, db_server_service, server = dos_bot_and_db_server
|
||||
|
||||
dos_bot.operating_state = ApplicationOperatingState.RUNNING
|
||||
dos_bot.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
dos_bot._perform_dos()
|
||||
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
|
||||
|
||||
def test_dos_blocks_green_agent_connection(dos_bot_db_server_green_client):
|
||||
network: Network = dos_bot_db_server_green_client
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot")
|
||||
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
dos_bot.port_scan_p_of_success = 1
|
||||
dos_bot.repeat = False
|
||||
dos_bot.run()
|
||||
|
||||
# DoS bot fills up connection of db server service
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(db_server_service.connections) == db_server_service.max_sessions
|
||||
assert len(dos_bot.connections) == db_server_service.max_sessions
|
||||
assert len(green_db_client.connections) == 0
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.COMPLETED
|
||||
# db server service is overwhelmed
|
||||
assert db_server_service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
# green agent tries to connect but fails because service is overwhelmed
|
||||
assert green_db_client.connect() is False
|
||||
assert len(green_db_client.connections) == 0
|
||||
@@ -1,6 +1,9 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
@@ -8,57 +11,109 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
def test_database_client_server_connection(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
@pytest.fixture(scope="function")
|
||||
def peer_to_peer() -> Tuple[Node, Node]:
|
||||
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
|
||||
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
|
||||
node_a.connect_nic(nic_a)
|
||||
node_a.software_manager.get_open_ports()
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
|
||||
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
|
||||
node_b.connect_nic(nic_b)
|
||||
|
||||
Link(endpoint_a=nic_a, endpoint_b=nic_b)
|
||||
|
||||
assert node_a.ping("192.168.0.11")
|
||||
|
||||
node_a.software_manager.install(DatabaseClient)
|
||||
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
node_a.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
node_b.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.start()
|
||||
return node_a, node_b
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def peer_to_peer_secure_db() -> Tuple[Node, Node]:
|
||||
node_a = Node(hostname="node_a", operating_state=NodeOperatingState.ON)
|
||||
nic_a = NIC(ip_address="192.168.0.10", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON)
|
||||
node_a.connect_nic(nic_a)
|
||||
node_a.software_manager.get_open_ports()
|
||||
|
||||
node_b = Node(hostname="node_b", operating_state=NodeOperatingState.ON)
|
||||
nic_b = NIC(ip_address="192.168.0.11", subnet_mask="255.255.255.0")
|
||||
node_b.connect_nic(nic_b)
|
||||
|
||||
Link(endpoint_a=nic_a, endpoint_b=nic_b)
|
||||
|
||||
assert node_a.ping("192.168.0.11")
|
||||
|
||||
node_a.software_manager.install(DatabaseClient)
|
||||
node_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
node_a.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
node_b.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
database_service.password = "12345"
|
||||
database_service.start()
|
||||
return node_a, node_b
|
||||
|
||||
|
||||
def test_database_client_server_connection(peer_to_peer):
|
||||
node_a, node_b = peer_to_peer
|
||||
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
db_client.disconnect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_server_correct_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
def test_database_client_server_correct_password(peer_to_peer_secure_db):
|
||||
node_a, node_b = peer_to_peer_secure_db
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_client.disconnect()
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
|
||||
db_service.password = "12345"
|
||||
|
||||
assert db_client.connect()
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345")
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
|
||||
def test_database_client_server_incorrect_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
def test_database_client_server_incorrect_password(peer_to_peer_secure_db):
|
||||
node_a, node_b = peer_to_peer_secure_db
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_client.disconnect()
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
|
||||
db_service.password = "12345"
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
assert not db_client.connect()
|
||||
# should fail
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="wrongpass")
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
assert db_client.connected
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client.connect()
|
||||
|
||||
assert db_client.query("SELECT")
|
||||
|
||||
@@ -66,13 +121,13 @@ def test_database_client_query(uc2_network):
|
||||
def test_create_database_backup(uc2_network):
|
||||
"""Run the backup_database method and check if the FTP server has the relevant file."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
# back up should be created
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
|
||||
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
|
||||
# backup file should exist in the backup server
|
||||
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
|
||||
@@ -81,7 +136,7 @@ def test_create_database_backup(uc2_network):
|
||||
def test_restore_backup(uc2_network):
|
||||
"""Run the restore_backup method and check if the backup is properly restored."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
# create a back up
|
||||
assert db_service.backup_database() is True
|
||||
@@ -107,7 +162,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
assert db_client.connected
|
||||
assert len(db_client.connections)
|
||||
|
||||
assert db_client.query("SELECT") is True
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
|
||||
DataManipulationAttackStage,
|
||||
DataManipulationBot,
|
||||
)
|
||||
@@ -70,4 +70,4 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
|
||||
dm_bot._perform_data_manipulation(p_of_success=1.0)
|
||||
|
||||
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
|
||||
assert dm_bot.connected
|
||||
assert len(dm_bot.connections)
|
||||
@@ -0,0 +1,90 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSAttackStage, DoSBot
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dos_bot() -> DoSBot:
|
||||
computer = Computer(
|
||||
hostname="compromised_pc",
|
||||
ip_address="192.168.0.1",
|
||||
subnet_mask="255.255.255.0",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
computer.software_manager.install(DoSBot)
|
||||
|
||||
dos_bot: DoSBot = computer.software_manager.software.get("DoSBot")
|
||||
dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1"))
|
||||
dos_bot.set_original_state()
|
||||
return dos_bot
|
||||
|
||||
|
||||
def test_dos_bot_creation(dos_bot):
|
||||
"""Test that the DoS bot is installed on a node."""
|
||||
assert dos_bot is not None
|
||||
|
||||
|
||||
def test_dos_bot_reset(dos_bot):
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
|
||||
# should reset the relevant items
|
||||
dos_bot.reset_component_for_episode(episode=0)
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.0.1")
|
||||
assert dos_bot.target_port is Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload is None
|
||||
assert dos_bot.repeat is False
|
||||
|
||||
dos_bot.configure(
|
||||
target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True
|
||||
)
|
||||
dos_bot.set_original_state()
|
||||
dos_bot.reset_component_for_episode(episode=1)
|
||||
# should reset to the configured value
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.1")
|
||||
assert dos_bot.target_port is Port.HTTP
|
||||
assert dos_bot.payload == "payload"
|
||||
assert dos_bot.repeat is True
|
||||
|
||||
|
||||
def test_dos_bot_cannot_run_when_node_offline(dos_bot):
|
||||
dos_bot_node: Computer = dos_bot.parent
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.ON
|
||||
|
||||
dos_bot_node.power_off()
|
||||
|
||||
for i in range(dos_bot_node.shut_down_duration + 1):
|
||||
dos_bot_node.apply_timestep(timestep=i)
|
||||
|
||||
assert dos_bot_node.operating_state is NodeOperatingState.OFF
|
||||
|
||||
dos_bot._application_loop()
|
||||
|
||||
# assert not run
|
||||
assert dos_bot.attack_stage is DoSAttackStage.NOT_STARTED
|
||||
|
||||
|
||||
def test_dos_bot_not_configured(dos_bot):
|
||||
dos_bot.target_ip_address = None
|
||||
|
||||
dos_bot.operating_state = ApplicationOperatingState.RUNNING
|
||||
dos_bot._application_loop()
|
||||
|
||||
|
||||
def test_dos_bot_perform_port_scan(dos_bot):
|
||||
dos_bot._perform_port_scan(p_of_success=1)
|
||||
|
||||
assert dos_bot.attack_stage is DoSAttackStage.PORT_SCAN
|
||||
@@ -1,5 +1,6 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -62,18 +63,26 @@ def test_disconnect_when_client_is_closed(database_client_on_computer):
|
||||
|
||||
|
||||
def test_disconnect(database_client_on_computer):
|
||||
"""Database client should set connected to False and remove the database server ip address."""
|
||||
"""Database client should remove the connection."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = True
|
||||
database_client._connections[str(uuid4())] = {"item": True}
|
||||
assert len(database_client.connections) == 1
|
||||
|
||||
assert database_client.operating_state is ApplicationOperatingState.RUNNING
|
||||
assert database_client.server_ip_address is not None
|
||||
|
||||
database_client.disconnect()
|
||||
|
||||
assert database_client.connected is False
|
||||
assert database_client.server_ip_address is None
|
||||
assert len(database_client.connections) == 0
|
||||
|
||||
uuid = str(uuid4())
|
||||
database_client._connections[uuid] = {"item": True}
|
||||
assert len(database_client.connections) == 1
|
||||
|
||||
database_client.disconnect(connection_id=uuid)
|
||||
|
||||
assert len(database_client.connections) == 0
|
||||
|
||||
|
||||
def test_query_when_client_is_closed(database_client_on_computer):
|
||||
@@ -86,19 +95,6 @@ def test_query_when_client_is_closed(database_client_on_computer):
|
||||
assert database_client.query(sql="test") is False
|
||||
|
||||
|
||||
def test_query_failed_reattempt(database_client_on_computer):
|
||||
"""Database client query should return False if the reattempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
def return_false():
|
||||
return False
|
||||
|
||||
database_client.connect = return_false
|
||||
|
||||
database_client.connected = False
|
||||
assert database_client.query(sql="test", is_reattempt=True) is False
|
||||
|
||||
|
||||
def test_query_fail_to_connect(database_client_on_computer):
|
||||
"""Database client query should return False if the connect attempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -66,3 +68,34 @@ def test_enable_disable(service):
|
||||
|
||||
service.enable()
|
||||
assert service.operating_state == ServiceOperatingState.STOPPED
|
||||
|
||||
|
||||
def test_overwhelm_service(service):
|
||||
service.max_sessions = 2
|
||||
service.start()
|
||||
|
||||
uuid = str(uuid4())
|
||||
assert service.add_connection(connection_id=uuid) # should be true
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=uuid) # fails because connection already exists
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert service.add_connection(connection_id=str(uuid4())) # succeed
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert not service.add_connection(connection_id=str(uuid4())) # fail because at capacity
|
||||
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
|
||||
def test_create_and_remove_connections(service):
|
||||
service.start()
|
||||
uuid = str(uuid4())
|
||||
|
||||
assert service.add_connection(connection_id=uuid) # should be true
|
||||
assert len(service.connections) == 1
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert service.remove_connection(connection_id=uuid) # should be true
|
||||
assert len(service.connections) == 0
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
Reference in New Issue
Block a user