#2869 - Update type hints and ConfigSchema variables in some agent classes

This commit is contained in:
Charlie Crane
2025-01-13 15:08:48 +00:00
parent 32fc970cfe
commit edd2668ea4
4 changed files with 33 additions and 35 deletions

View File

@@ -6,7 +6,7 @@ from abc import abstractmethod
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.agent_log import AgentLog
@@ -72,32 +72,6 @@ class AbstractAgent(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
type: str = "AbstractAgent"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance).
If variance were greater than frequency, sometimes the bracketed term would be negative
and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
if identifier in cls._registry:

View File

@@ -1,10 +1,11 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
"""Agents with predefined behaviours."""
from typing import Any, Dict, Tuple
from typing import Dict, Tuple
import numpy as np
import pydantic
from gymnasium.core import ObsType
from numpy.random import Generator
from pydantic import Field
from primaite.game.agent.interface import AbstractScriptedAgent
@@ -16,7 +17,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
rng: Generator = np.random.default_rng(np.random.randint(0, 65535))
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
"""Configuration schema for Probabilistic Agent."""

View File

@@ -3,7 +3,7 @@ import random
from typing import Dict, Tuple
from gymnasium.core import ObsType
from pydantic import Field
from pydantic import Field, model_validator
from primaite.game.agent.interface import AbstractScriptedAgent
@@ -43,6 +43,28 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
type: str = "PeriodicAgent"
"""Name of the agent."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance).
If variance were greater than frequency, sometimes the bracketed term would be negative
and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
max_executions: int = 999999
"Maximum number of times the agent can execute its action."

View File

@@ -3,7 +3,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import PrimaiteGame
from primaite.game.game import PrimaiteGame, PrimaiteGameOptions
def test_probabilistic_agent():
@@ -54,16 +54,17 @@ def test_probabilistic_agent():
},
"options": {},
}
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
game = PrimaiteGame()
game.options = PrimaiteGameOptions(ports=[], protocols=[])
observation_space_cfg = None
reward_function_cfg = {}
pa_config = {
"agent_name": "test_agent",
"game": PrimaiteGame(),
"type": "ProbabilisticAgent",
"game": game,
"action_manager": action_space_cfg,
"observation_manager": observation_space_cfg,
"reward_function": reward_function_cfg,