Pass execution definition from config to agent

This commit is contained in:
Jake Walker
2023-11-17 11:51:19 +00:00
parent 1c5ff66d26
commit 227e73602f
3 changed files with 42 additions and 7 deletions

View File

@@ -1,6 +1,6 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
import numpy as np
from pydantic import BaseModel
@@ -8,13 +8,21 @@ from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Node
if TYPE_CHECKING:
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AgentExecutionDefinition(BaseModel):
"""Additional configuration for agents."""
port_scan_p_of_success: float = 0.1
"The probability of a port scan succeeding."
data_manipulation_p_of_success: float = 0.1
"The probability of data manipulation succeeding."
class AbstractAgent(ABC):
@@ -26,7 +34,7 @@ class AbstractAgent(ABC):
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
execution_definition: Optional[AgentExecutionDefinition]
execution_definition: Optional[AgentExecutionDefinition],
) -> None:
"""
Initialize an agent.
@@ -116,13 +124,38 @@ class RandomAgent(AbstractScriptedAgent):
"""
return self.action_space.get_action(self.action_space.space.sample())
class DataManipulationAgent(AbstractScriptedAgent):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# get node ids that are part of the agent's observation space
node_ids: List[str] = [n.where[-1] for n in self.observation_space.obs.nodes]
# get all nodes from their ids
nodes: List[Node] = [n for n_id, n in self.action_space.sim.network.nodes.items() if n_id in node_ids]
# get execution definition for data manipulation bots
for node in nodes:
bot_sw: Optional["DataManipulationBot"] = node.software_manager.software.get("DataManipulationBot")
if bot_sw is not None:
bot_sw.execution_definition = self.execution_definition
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""

View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent, DataManipulationAgent, AgentExecutionDefinition
from primaite.game.agent.interface import AbstractAgent, AgentExecutionDefinition, DataManipulationAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node

View File

@@ -2,6 +2,7 @@ from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite.game.agent.interface import AgentExecutionDefinition
from primaite.game.science import simulate_trial
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
@@ -14,6 +15,7 @@ class DataManipulationAttackStage(IntEnum):
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the
simulation. Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
LOGON = 1
@@ -30,17 +32,19 @@ class DataManipulationAttackStage(IntEnum):
class DataManipulationBot(DatabaseClient):
"""A bot that simulates a script which performs a SQL injection attack."""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED
execution_definition: AgentExecutionDefinition = AgentExecutionDefinition()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
@@ -96,7 +100,6 @@ class DataManipulationBot(DatabaseClient):
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
# perform the actual data manipulation attack
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the attack
if not self.connected:
@@ -114,7 +117,7 @@ class DataManipulationBot(DatabaseClient):
def execute(self):
"""
Execute the Data Manipulation Bot
Execute the Data Manipulation Bot.
Calls the parent classes execute method before starting the application loop.
"""
@@ -127,7 +130,6 @@ class DataManipulationBot(DatabaseClient):
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
return
if self.server_ip_address and self.payload and self.operating_state: