#2869 - Update type hints and ConfigSchema variables in some agent classes
This commit is contained in:
@@ -6,7 +6,7 @@ from abc import abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.agent_log import AgentLog
|
||||
@@ -72,32 +72,6 @@ class AbstractAgent(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str = "AbstractAgent"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
"Whether to return action masks at each step."
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing its actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount by which the frequency can randomly vary (next execution is frequency +- variance)"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance).
|
||||
If variance were greater than frequency, sometimes the bracketed term would be negative
|
||||
and the attack would never happen again.
|
||||
"""
|
||||
if self.variance > self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
if identifier in cls._registry:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Agents with predefined behaviours."""
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
from gymnasium.core import ObsType
|
||||
from numpy.random import Generator
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
@@ -16,7 +17,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent")
|
||||
"""Scripted agent which randomly samples its action space with prescribed probabilities for each action."""
|
||||
|
||||
config: "ProbabilisticAgent.ConfigSchema" = Field(default_factory=lambda: ProbabilisticAgent.ConfigSchema())
|
||||
rng: Any = np.random.default_rng(np.random.randint(0, 65535))
|
||||
rng: Generator = np.random.default_rng(np.random.randint(0, 65535))
|
||||
|
||||
class ConfigSchema(AbstractScriptedAgent.ConfigSchema):
|
||||
"""Configuration schema for Probabilistic Agent."""
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
@@ -43,6 +43,28 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="PeriodicAgent"):
|
||||
|
||||
type: str = "PeriodicAgent"
|
||||
"""Name of the agent."""
|
||||
start_step: int = 5
|
||||
"The timestep at which an agent begins performing its actions"
|
||||
frequency: int = 5
|
||||
"The number of timesteps to wait between performing actions"
|
||||
variance: int = 0
|
||||
"The amount by which the frequency can randomly vary (next execution is frequency +- variance)"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "PeriodicAgent.ConfigSchema":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance).
|
||||
If variance were greater than frequency, sometimes the bracketed term would be negative
|
||||
and the attack would never happen again.
|
||||
"""
|
||||
if self.variance > self.frequency:
|
||||
raise ValueError(
|
||||
f"Agent start settings error: variance must be lower than frequency "
|
||||
f"{self.variance=}, {self.frequency=}"
|
||||
)
|
||||
return self
|
||||
|
||||
max_executions: int = 999999
|
||||
"Maximum number of times the agent can execute its action."
|
||||
|
||||
@@ -3,7 +3,7 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.game.game import PrimaiteGame, PrimaiteGameOptions
|
||||
|
||||
|
||||
def test_probabilistic_agent():
|
||||
@@ -54,16 +54,17 @@ def test_probabilistic_agent():
|
||||
},
|
||||
"options": {},
|
||||
}
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
game = PrimaiteGame()
|
||||
game.options = PrimaiteGameOptions(ports=[], protocols=[])
|
||||
|
||||
observation_space_cfg = None
|
||||
|
||||
reward_function_cfg = {}
|
||||
|
||||
pa_config = {
|
||||
"agent_name": "test_agent",
|
||||
"game": PrimaiteGame(),
|
||||
"type": "ProbabilisticAgent",
|
||||
"game": game,
|
||||
"action_manager": action_space_cfg,
|
||||
"observation_manager": observation_space_cfg,
|
||||
"reward_function": reward_function_cfg,
|
||||
|
||||
Reference in New Issue
Block a user