From 80158fd9b4e1b2beeff3c42843c1f50b0dbc6716 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 11:18:06 +0000 Subject: [PATCH] Make db manipulation bot work with db client --- src/primaite/game/game.py | 6 +-- .../network/transmission/network_layer.py | 2 + .../system/applications/database_client.py | 9 +++- .../red_applications/data_manipulation_bot.py | 48 +++++++++++++++---- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cf21dd40..10c02b39 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -409,9 +409,6 @@ class PrimaiteGame: # CREATE REWARD FUNCTION reward_function = RewardFunction.from_config(reward_function_cfg) - # OTHER AGENT SETTINGS - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - # CREATE AGENT if agent_type == "probabilistic_agent": # TODO: implement non-random agents and fix this parsing @@ -424,6 +421,7 @@ class PrimaiteGame: settings=settings, ) elif agent_type == "ProxyAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -433,6 +431,8 @@ class PrimaiteGame: ) game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index dc848ade..22d7f97d 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -15,6 +15,8 @@ class IPProtocol(Enum): .. _List of IPProtocols: """ + NONE = "none" + """Placeholder for a non-port.""" TCP = "tcp" """Transmission Control Protocol.""" UDP = "udp" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 69065225..a8eac196 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -84,7 +84,14 @@ class DatabaseClient(Application): ) return self.connected - def check_connection(self, connection_id:str) -> bool: + def check_connection(self, connection_id: str) -> bool: + """Check whether the connection can be successfully re-established. + + :param connection_id: connection ID to check + :type connection_id: str + :return: Whether the connection was successfully re-established. + :rtype: bool + """ if not self._can_perform_action(): return False print(self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 5fe951b7..11eb71f5 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -1,10 +1,13 @@ from enum import IntEnum from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -32,12 +35,12 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(DatabaseClient): +class DataManipulationBot(Application): """A bot that simulates a script which performs a SQL injection attack.""" - server_ip_address: Optional[IPv4Address] = None + # server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None - server_password: Optional[str] = None + # server_password: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -46,8 +49,31 @@ class DataManipulationBot(DatabaseClient): "Whether to repeat attacking once finished." def __init__(self, **kwargs): + kwargs["name"] = "DataManipulationBot" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) - self.name = "DataManipulationBot" + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + state = super().describe_state() + return state + + @property + def _host_db_client(self) -> DatabaseClient: + """Return the database client that is installed on the same machine as the DataManipulationBot.""" + db_client = self.software_manager.software.get("DatabaseClient") + if db_client is None: + _LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.") + return db_client def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -76,8 +102,8 @@ class DataManipulationBot(DatabaseClient): :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address - self.payload = payload self.server_password = server_password + self.payload = payload self.port_scan_p_of_success = port_scan_p_of_success self.data_manipulation_p_of_success = data_manipulation_p_of_success self.repeat = repeat @@ -123,15 +149,17 @@ class DataManipulationBot(DatabaseClient): :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. """ + self._host_db_client.server_ip_address = self.server_ip_address + self._host_db_client.server_password = self.server_password if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not len(self.connections): - self.connect() - if len(self.connections): - self.query(self.payload) + if not len(self._host_db_client.connections): + self._host_db_client.connect() + if len(self._host_db_client.connections): + self._host_db_client.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True if attack_successful: