From 2975aa882774c3b5979072646de64c243ab880b4 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Tue, 21 Nov 2023 11:42:01 +0000 Subject: [PATCH] Execute data manipulation bots from agent --- src/primaite/game/agent/interface.py | 38 ++++++++++++++++++- src/primaite/game/session.py | 4 +- .../system/applications/database_client.py | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 70eb1980..94878947 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -24,6 +24,20 @@ class AgentExecutionDefinition(BaseModel): data_manipulation_p_of_success: float = 0.1 "The probability of data manipulation succeeding." + @classmethod + def from_config(cls, config: Optional[Dict]) -> "AgentExecutionDefinition": + """Construct an AgentExecutionDefinition from a config dictionary. + + :param config: A dict of options for the execution definition. + :type config: Dict + :return: The execution definition. + :rtype: AgentExecutionDefinition + """ + if config is None: + return cls() + + return cls(**config) + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -42,6 +56,20 @@ class AgentSettings(BaseModel): start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" + @classmethod + def from_config(cls, config: Optional[Dict]) -> "AgentSettings": + """Construct agent settings from a config dictionary. + + :param config: A dict of options for the agent settings. + :type config: Dict + :return: The agent settings. + :rtype: AgentSettings + """ + if config is None: + return cls() + + return cls(**config) + class AbstractAgent(ABC): """Base class for scripted and RL agents.""" @@ -149,6 +177,8 @@ class RandomAgent(AbstractScriptedAgent): class DataManipulationAgent(AbstractScriptedAgent): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + data_manipulation_bots: List["DataManipulationBot"] = [] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -163,6 +193,7 @@ class DataManipulationAgent(AbstractScriptedAgent): if bot_sw is not None: bot_sw.execution_definition = self.execution_definition + self.data_manipulation_bots.append(bot_sw) def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. @@ -174,7 +205,12 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - return self.action_space.get_action(self.action_space.space.sample()) + # TODO: Move this to the appropriate place + # return self.action_space.get_action(self.action_space.space.sample()) + for bot in self.data_manipulation_bots: + bot.execute() + + return ("DONOTHING", {"dummy": 0}) class AbstractGATEAgent(AbstractAgent): diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 9701fec9..1b086c35 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -444,8 +444,8 @@ class PrimaiteSession: # CREATE REWARD FUNCTION rew_function = RewardFunction.from_config(reward_function_cfg, session=sess) - execution_definition = AgentExecutionDefinition(**agent_cfg.get("execution_definition", {})) - agent_settings = AgentSettings(**agent_cfg.get("agent_settings", {})) + execution_definition = AgentExecutionDefinition.from_config(agent_cfg.get("execution_definition")) + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT if agent_type == "GreenWebBrowsingAgent": diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 28e826fd..e15249e3 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -130,7 +130,7 @@ class DatabaseClient(Application): def execute(self) -> None: """Run the DatabaseClient.""" - super().execute() + # super().execute() if self.operating_state == ApplicationOperatingState.RUNNING: self.connect()