"""Interface for agents.""" import random from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType from pydantic import BaseModel from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction if TYPE_CHECKING: from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" start_step: int = 5 "The timestep at which an agent begins performing it's actions" frequency: int = 5 "The number of timesteps to wait between performing actions" variance: int = 0 "The amount the frequency can randomly change to" class AgentSettings(BaseModel): """Settings for configuring the operation of an agent.""" start_settings: Optional[AgentStartSettings] = None "Configuration for when an agent begins performing it's actions" @classmethod def from_config(cls, config: Optional[Dict]) -> "AgentSettings": """Construct agent settings from a config dictionary. :param config: A dict of options for the agent settings. :type config: Dict :return: The agent settings. :rtype: AgentSettings """ if config is None: return cls() return cls(**config) class AbstractAgent(ABC): """Base class for scripted and RL agents.""" def __init__( self, agent_name: Optional[str], action_space: Optional[ActionManager], observation_space: Optional[ObservationManager], reward_function: Optional[RewardFunction], agent_settings: Optional[AgentSettings] = None, ) -> None: """ Initialize an agent. :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. :type agent_name: Optional[str] :param action_space: Action space for the agent. :type action_space: Optional[ActionManager] :param observation_space: Observation space for the agent. :type observation_space: Optional[ObservationSpace] :param reward_function: Reward function for the agent. :type reward_function: Optional[RewardFunction] """ self.agent_name: str = agent_name or "unnamed_agent" self.action_manager: Optional[ActionManager] = action_space self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ return self.observation_manager.update(state) def update_reward(self, state: Dict) -> float: """ Use the reward function to calculate a reward from the state. :param state: State of the environment. :type state: Dict :return: Reward from the state. :rtype: float """ return self.reward_function.update(state) @abstractmethod def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. Subclasses should implement agent logic here. It should use the observation as input to decide best next action. :param obs: Observation of the environment. :type obs: ObsType :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? :type reward: float, optional :return: Action to be taken in the environment. 
class AbstractScriptedAgent(AbstractAgent):
    """Base class for actors which generate their own behaviour."""

    ...


class RandomAgent(AbstractScriptedAgent):
    """Agent that ignores its observation and acts completely at random."""

    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
        """Randomly sample an action from the action space.

        :param obs: Observation of the environment. Not used by this agent.
        :type obs: ObsType
        :param reward: Reward from the previous action, defaults to 0.0. Not used by this agent.
        :type reward: float, optional
        :return: Action to be taken in the environment.
        :rtype: Tuple[str, Dict]
        """
        return self.action_manager.get_action(self.action_manager.space.sample())


class ProxyAgent(AbstractAgent):
    """Agent that sends observations to an RL model and receives actions from that model."""

    def __init__(
        self,
        agent_name: Optional[str],
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationManager],
        reward_function: Optional[RewardFunction],
    ) -> None:
        super().__init__(
            agent_name=agent_name,
            action_space=action_space,
            observation_space=observation_space,
            reward_function=reward_function,
        )
        self.most_recent_action: ActType

    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
        """
        Return the agent's most recent action, formatted in CAOS format.

        :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
        :type obs: ObsType
        :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to 0.0.
        :type reward: float, optional
        :return: Action to be taken in CAOS format.
        :rtype: Tuple[str, Dict]
        """
        return self.action_manager.get_action(self.most_recent_action)

    def store_action(self, action: ActType) -> None:
        """
        Store the most recent action taken by the agent.

        The environment is responsible for calling this method when it receives an action from the agent policy.

        :param action: Action chosen by the agent policy.
        :type action: ActType
        """
        self.most_recent_action = action
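
# Illustrative sketch, not part of the PrimAITE API: the round trip an
# environment is expected to perform with a ProxyAgent each step. The helper
# name and its arguments are assumptions for demonstration purposes only.
def _example_proxy_round_trip(proxy: ProxyAgent, policy_action: ActType, obs: ObsType) -> List[str]:
    """Show how an environment drives a ProxyAgent with a raw policy action."""
    proxy.store_action(policy_action)  # the env stores the raw output of the RL policy
    action, options = proxy.get_action(obs)  # the stored action is converted to a CAOS (verb, options) pair
    return proxy.format_request(action, options)  # and finally into a request for the simulator
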
class DataManipulationAgent(AbstractScriptedAgent):
    """Agent that uses a DataManipulationBot to perform an SQL injection attack."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Instance attributes rather than class attributes, so that bots and schedules are not shared between agents.
        self.data_manipulation_bots: List["DataManipulationBot"] = []
        self.next_execution_timestep: int = 0
        self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)

    def _set_next_execution_timestep(self, timestep: int) -> None:
        """Set the next execution timestep with a configured random variance.

        :param timestep: The timestep to add variance to.
        """
        random_timestep_increment = random.randint(
            -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
        )
        self.next_execution_timestep = timestep + random_timestep_increment

    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
        """Execute the data manipulation application when the scheduled timestep is reached, otherwise do nothing.

        :param obs: Observation of the environment. Not used by this agent.
        :type obs: ObsType
        :param reward: Reward from the previous action, defaults to 0.0. Not used by this agent.
        :type reward: float, optional
        :return: Action to be taken in the environment.
        :rtype: Tuple[str, Dict]
        """
        current_timestep = self.action_manager.session.step_counter
        if current_timestep < self.next_execution_timestep:
            return "DONOTHING", {"dummy": 0}

        self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
        return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}


class AbstractGATEAgent(AbstractAgent):
    """Base class for actors controlled via external messages, such as RL policies."""

    ...
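
# Usage sketch (illustrative, not a prescribed entry point): building agent
# settings from a config dict, as a scenario loader would, and showing how the
# scheduling variance plays out. The config keys mirror the pydantic fields
# defined at the top of this module.
if __name__ == "__main__":
    example_settings = AgentSettings.from_config(
        {"start_settings": {"start_step": 10, "frequency": 5, "variance": 2}}
    )
    # With variance 2, the first execution timestep lands in [8, 12]:
    # start_step + randint(-variance, variance).
    print(example_settings.start_settings)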