From 227e73602f8468523da60e8fe983622959d9ae92 Mon Sep 17 00:00:00 2001 From: Jake Walker Date: Fri, 17 Nov 2023 11:51:19 +0000 Subject: [PATCH] Pass execution definition from config to agent --- src/primaite/game/agent/interface.py | 37 ++++++++++++++++++- src/primaite/game/session.py | 2 +- .../red_services/data_manipulation_bot.py | 10 +++-- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index d04b298e..c591c554 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,6 +1,6 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union import numpy as np from pydantic import BaseModel @@ -8,13 +8,21 @@ from pydantic import BaseModel from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction +from primaite.simulator.network.hardware.base import Node + +if TYPE_CHECKING: + from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot ObsType: TypeAlias = Union[Dict, np.ndarray] class AgentExecutionDefinition(BaseModel): + """Additional configuration for agents.""" + port_scan_p_of_success: float = 0.1 + "The probability of a port scan succeeding." data_manipulation_p_of_success: float = 0.1 + "The probability of data manipulation succeeding." class AbstractAgent(ABC): @@ -26,7 +34,7 @@ class AbstractAgent(ABC): action_space: Optional[ActionManager], observation_space: Optional[ObservationSpace], reward_function: Optional[RewardFunction], - execution_definition: Optional[AgentExecutionDefinition] + execution_definition: Optional[AgentExecutionDefinition], ) -> None: """ Initialize an agent. @@ -116,13 +124,38 @@ class RandomAgent(AbstractScriptedAgent): """ return self.action_space.get_action(self.action_space.space.sample()) + class DataManipulationAgent(AbstractScriptedAgent): + """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # get node ids that are part of the agent's observation space + node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes] + # get all nodes from their ids + nodes: List[Node] = [n for n_id, n in self.action_space.sim.network.nodes.items() if n_id in node_ids] + + # get execution definition for data manipulation bots + for node in nodes: + bot_sw: Optional["DataManipulationBot"] = node.software_manager.software.get("DataManipulationBot") + + if bot_sw is not None: + bot_sw.execution_definition = self.execution_definition + def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + """Randomly sample an action from the action space. + + :param obs: _description_ + :type obs: ObsType + :param reward: _description_, defaults to None + :type reward: float, optional + :return: _description_ + :rtype: Tuple[str, Dict] + """ return self.action_space.get_action(self.action_space.space.sample()) + class AbstractGATEAgent(AbstractAgent): """Base class for actors controlled via external messages, such as RL policies.""" diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 082ed281..5f3fb7b9 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition +from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent from primaite.game.agent.observations import ObservationSpace from primaite.game.agent.rewards import RewardFunction from primaite.simulator.network.hardware.base import Link, NIC, Node diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index aec7bbd8..35ea413a 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -2,6 +2,7 @@ from enum import IntEnum from ipaddress import IPv4Address from typing import Optional +from primaite.game.agent.interface import AgentExecutionDefinition from primaite.game.science import simulate_trial from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient @@ -14,6 +15,7 @@ class DataManipulationAttackStage(IntEnum): This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the simulation. Each stage represents a specific phase in the attack process. """ + NOT_STARTED = 0 "Indicates that the attack has not started yet." LOGON = 1 @@ -30,17 +32,19 @@ class DataManipulationAttackStage(IntEnum): class DataManipulationBot(DatabaseClient): """A bot that simulates a script which performs a SQL injection attack.""" + server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None server_password: Optional[str] = None attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED + execution_definition: AgentExecutionDefinition = AgentExecutionDefinition() def __init__(self, **kwargs): super().__init__(**kwargs) self.name = "DataManipulationBot" def configure( - self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None ): """ Configure the DataManipulatorBot to communicate with a DatabaseService. @@ -96,7 +100,6 @@ class DataManipulationBot(DatabaseClient): if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): - self.sys_log.info(f"{self.name}: Performing port scan") # perform the attack if not self.connected: @@ -114,7 +117,7 @@ class DataManipulationBot(DatabaseClient): def execute(self): """ - Execute the Data Manipulation Bot + Execute the Data Manipulation Bot. Calls the parent classes execute method before starting the application loop. """ @@ -127,7 +130,6 @@ class DataManipulationBot(DatabaseClient): This is the core loop where the bot sequentially goes through the stages of the attack. """ - if self.operating_state != ApplicationOperatingState.RUNNING: return if self.server_ip_address and self.payload and self.operating_state: