Files
PrimAITE/src/primaite/game/agent/scripted_agents/random_agent.py

122 lines
5.0 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import random
from functools import cached_property
from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from pydantic import computed_field, Field, model_validator
from primaite.game.agent.interface import AbstractScriptedAgent
__all__ = ("RandomAgent", "PeriodicAgent")
class RandomAgent(AbstractScriptedAgent, identifier="RandomAgent"):
"""Agent that ignores its observation and acts completely at random."""
config: "RandomAgent.ConfigSchema" = Field(default_factory=lambda: RandomAgent.ConfigSchema())
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Random Agents."""
type: str = "RandomAgent"
def get_action(self) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: Current observation for this agent, not used in RandomAgent
:type obs: ObsType
:param timestep: The current simulation timestep, not used in RandomAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
"""Agent that does nothing most of the time, but executes application at regular intervals (with variance)."""
config: "PeriodicAgent.ConfigSchema" = Field(default_factory=lambda: PeriodicAgent.ConfigSchema())
class AgentSettingsSchema(AbstractScriptedAgent.AgentSettingsSchema):
"""Schema for the `agent_settings` part of the agent config."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
start_variance: int = 0
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
max_executions: int = 999999
possible_start_nodes: List[str]
target_application: str
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance).
If variance were greater than frequency, sometimes the bracketed term would be negative
and the attack would never happen again.
"""
if self.variance >= self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration Schema for Periodic Agent."""
type: str = "PeriodicAgent"
"""Name of the agent."""
agent_settings: "PeriodicAgent.AgentSettingsSchema" = Field(
default_factory=lambda: PeriodicAgent.AgentSettingsSchema()
)
num_executions: int = 0
"""Number of times the agent has executed an action."""
next_execution_timestep: int = 0
"""Timestep of the next action execution by the agent."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._set_next_execution_timestep(
timestep=self.config.agent_settings.start_step, variance=self.config.agent_settings.start_variance
)
@computed_field
@cached_property
def start_node(self) -> str:
"""On instantiation, randomly select a start node."""
return random.choice(self.config.agent_settings.possible_start_nodes)
def _set_next_execution_timestep(self, timestep: int, variance: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep when the next execute action should be taken.
:type timestep: int
:param variance: Uniform random variance applied to the timestep
:type variance: int
"""
random_increment = random.randint(-variance, variance)
self.next_execution_timestep = timestep + random_increment
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Do nothing, unless the current timestep is the next execution timestep, in which case do the action."""
if timestep == self.next_execution_timestep and self.num_executions < self.config.agent_settings.max_executions:
self.num_executions += 1
self._set_next_execution_timestep(
timestep + self.config.agent_settings.frequency, self.config.agent_settings.variance
)
return "node_application_execute", {
"node_name": self.start_node,
"application_name": self.config.agent_settings.target_application,
}
return "do_nothing", {}