diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 26445830..ac76a425 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,7 +6,7 @@ from abc import abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field from primaite.game.agent.actions import ActionManager from primaite.game.agent.agent_log import AgentLog @@ -72,32 +72,6 @@ class AbstractAgent(BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) type: str = "AbstractAgent" - flatten_obs: bool = True - "Whether to flatten the observation space before passing it to the agent. True by default." - action_masking: bool = False - "Whether to return action masks at each step." - start_step: int = 5 - "The timestep at which an agent begins performing it's actions" - frequency: int = 5 - "The number of timesteps to wait between performing actions" - variance: int = 0 - "The amount the frequency can randomly change to" - - @model_validator(mode="after") - def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema": - """ - Make sure variance is equal to or lower than frequency. - - This is because the calculation for the next execution time is now + (frequency +- variance). - If variance were greater than frequency, sometimes the bracketed term would be negative - and the attack would never happen again. - """ - if self.variance > self.frequency: - raise ValueError( - f"Agent start settings error: variance must be lower than frequency " - f"{self.variance=}, {self.frequency=}" - ) - return self def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: if identifier in cls._registry: diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index f3d9ee08..8e714f55 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -1,10 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Agents with predefined behaviours.""" -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import numpy as np import pydantic from gymnasium.core import ObsType +from numpy.random import Generator from pydantic import Field from primaite.game.agent.interface import AbstractScriptedAgent @@ -16,7 +17,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema()) - rng: Any = np.random.default_rng(np.random.randint(0, 65535)) + rng: Generator = np.random.default_rng(np.random.randint(0, 65535)) class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration schema for Probabilistic Agent.""" diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index daf810a8..b5601a58 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -3,7 +3,7 @@ import random from typing import Dict, Tuple from gymnasium.core import ObsType -from pydantic import Field +from pydantic import Field, model_validator from primaite.game.agent.interface import AbstractScriptedAgent @@ -43,6 +43,28 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"): type: str = "PeriodicAgent" """Name of the agent.""" + start_step: int = 5 + "The timestep at which an agent begins performing it's actions" + frequency: int = 5 + "The number of timesteps to wait between performing actions" + variance: int = 0 + "The amount the frequency can randomly change to" + + @model_validator(mode="after") + def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema": + """ + Make sure variance is equal to or lower than frequency. + + This is because the calculation for the next execution time is now + (frequency +- variance). + If variance were greater than frequency, sometimes the bracketed term would be negative + and the attack would never happen again. + """ + if self.variance > self.frequency: + raise ValueError( + f"Agent start settings error: variance must be lower than frequency " + f"{self.variance=}, {self.frequency=}" + ) + return self max_executions: int = 999999 "Maximum number of times the agent can execute its action." diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index 6e5fb94d..7035e98f 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -3,7 +3,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.game import PrimaiteGame +from primaite.game.game import PrimaiteGame, PrimaiteGameOptions def test_probabilistic_agent(): @@ -54,16 +54,17 @@ def test_probabilistic_agent(): }, "options": {}, } - observation_space = ObservationManager(NestedObservation(components={})) - reward_function = RewardFunction() + + game = PrimaiteGame() + game.options = PrimaiteGameOptions(ports=[], protocols=[]) observation_space_cfg = None reward_function_cfg = {} pa_config = { - "agent_name": "test_agent", - "game": PrimaiteGame(), + "type": "ProbabilisticAgent", + "game": game, "action_manager": action_space_cfg, "observation_manager": observation_space_cfg, "reward_function": reward_function_cfg,