Files
PrimAITE/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py

57 lines
2.4 KiB
Python

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class DataManipulationAgent(AbstractScriptedAgent):
    """Agent that uses a DataManipulationBot to perform an SQL injection attack.

    The agent returns ``DONOTHING`` until its scheduled timestep is reached, then requests execution of the
    application at index 0 on its start node and schedules the next execution.
    """

    # Timestep at (or after) which the next execute action is emitted.
    next_execution_timestep: int = 0
    # Index into ``action_manager.node_names`` of the node the application is executed on.
    starting_node_idx: int = 0

    def __init__(self, *args, **kwargs):
        """Initialise the base agent, then choose a start node and schedule the first execution."""
        super().__init__(*args, **kwargs)
        self.setup_agent()

    def _set_next_execution_timestep(self, timestep: int) -> None:
        """Set the next execution timestep with a configured random variance.

        The stored value is ``timestep`` shifted by a uniformly drawn integer in ``[-variance, +variance]``,
        where ``variance`` is read from ``agent_settings.start_settings``.

        :param timestep: The timestep to add variance to.
        :type timestep: int
        """
        variance = self.agent_settings.start_settings.variance
        random_timestep_increment = random.randint(-variance, variance)
        self.next_execution_timestep = timestep + random_timestep_increment

    def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
        """Waits until a specific timestep, then attempts to execute its data manipulation application.

        :param obs: Current observation for this agent, not used in DataManipulationAgent
        :type obs: ObsType
        :param timestep: The current simulation timestep, used for scheduling actions
        :type timestep: int
        :return: Action formatted in CAOS format
        :rtype: Tuple[str, Dict]
        """
        # Not yet due: stay idle.
        if timestep < self.next_execution_timestep:
            return "DONOTHING", {}

        # Due now: schedule the following execution ``frequency`` steps ahead (plus variance) before acting.
        self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)

        return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}

    def setup_agent(self) -> None:
        """Select a start node and set the first execution timestep.

        The first execution is scheduled around ``start_settings.start_step`` with the configured variance.
        """
        self._select_start_node()
        self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

    def _select_start_node(self) -> None:
        """Set the starting node of the agent to be a random node from this agent's action manager.

        :raises ValueError: If the action manager has no nodes to choose from.
        """
        # we are assuming that every node in the node manager has a data manipulation application at idx 0
        num_nodes = len(self.action_manager.node_names)
        if num_nodes == 0:
            # random.randint(0, -1) would fail with an opaque "empty range" ValueError;
            # raise the same exception type with a message that points at the real cause.
            raise ValueError(
                "DataManipulationAgent cannot select a start node: the action manager has no nodes."
            )
        self.starting_node_idx = random.randint(0, num_nodes - 1)