Pass execution definition from config to agent
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
@@ -8,13 +8,21 @@ from pydantic import BaseModel
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AgentExecutionDefinition(BaseModel):
|
||||
"""Additional configuration for agents."""
|
||||
|
||||
port_scan_p_of_success: float = 0.1
|
||||
"The probability of a port scan succeeding."
|
||||
data_manipulation_p_of_success: float = 0.1
|
||||
"The probability of data manipulation succeeding."
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
@@ -26,7 +34,7 @@ class AbstractAgent(ABC):
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
reward_function: Optional[RewardFunction],
|
||||
execution_definition: Optional[AgentExecutionDefinition]
|
||||
execution_definition: Optional[AgentExecutionDefinition],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an agent.
|
||||
@@ -116,13 +124,38 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
|
||||
class DataManipulationAgent(AbstractScriptedAgent):
|
||||
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# get node ids that are part of the agent's observation space
|
||||
node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes]
|
||||
# get all nodes from their ids
|
||||
nodes: List[Node] = [n for n_id, n in self.action_space.sim.network.nodes.items() if n_id in node_ids]
|
||||
|
||||
# get execution definition for data manipulation bots
|
||||
for node in nodes:
|
||||
bot_sw: Optional["DataManipulationBot"] = node.software_manager.software.get("DataManipulationBot")
|
||||
|
||||
if bot_sw is not None:
|
||||
bot_sw.execution_definition = self.execution_definition
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
:type obs: ObsType
|
||||
:param reward: _description_, defaults to None
|
||||
:type reward: float, optional
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
"""Base class for actors controlled via external messages, such as RL policies."""
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.game.agent.interface import AgentExecutionDefinition
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -14,6 +15,7 @@ class DataManipulationAttackStage(IntEnum):
|
||||
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the
|
||||
simulation. Each stage represents a specific phase in the attack process.
|
||||
"""
|
||||
|
||||
NOT_STARTED = 0
|
||||
"Indicates that the attack has not started yet."
|
||||
LOGON = 1
|
||||
@@ -30,17 +32,19 @@ class DataManipulationAttackStage(IntEnum):
|
||||
|
||||
class DataManipulationBot(DatabaseClient):
|
||||
"""A bot that simulates a script which performs a SQL injection attack."""
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
payload: Optional[str] = None
|
||||
server_password: Optional[str] = None
|
||||
attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED
|
||||
execution_definition: AgentExecutionDefinition = AgentExecutionDefinition()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = "DataManipulationBot"
|
||||
|
||||
def configure(
|
||||
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
|
||||
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Configure the DataManipulatorBot to communicate with a DatabaseService.
|
||||
@@ -96,7 +100,6 @@ class DataManipulationBot(DatabaseClient):
|
||||
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
|
||||
# perform the actual data manipulation attack
|
||||
if simulate_trial(p_of_success):
|
||||
|
||||
self.sys_log.info(f"{self.name}: Performing port scan")
|
||||
# perform the attack
|
||||
if not self.connected:
|
||||
@@ -114,7 +117,7 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
def execute(self):
|
||||
"""
|
||||
Execute the Data Manipulation Bot
|
||||
Execute the Data Manipulation Bot.
|
||||
|
||||
Calls the parent classes execute method before starting the application loop.
|
||||
"""
|
||||
@@ -127,7 +130,6 @@ class DataManipulationBot(DatabaseClient):
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
|
||||
if self.operating_state != ApplicationOperatingState.RUNNING:
|
||||
return
|
||||
if self.server_ip_address and self.payload and self.operating_state:
|
||||
|
||||
Reference in New Issue
Block a user