#2869 - Changes following review discussion

This commit is contained in:
Charlie Crane
2025-01-03 14:02:36 +00:00
parent 55ddcb7eb4
commit 505eab6ed9
2 changed files with 18 additions and 29 deletions

View File

@@ -49,13 +49,12 @@ class AbstractAgent(BaseModel):
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
_logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
config: "AbstractAgent.ConfigSchema"
history: List[AgentHistoryItem] = []
config: "AbstractAgent.ConfigSchema"
action_manager: ActionManager
observation_manager: ObservationManager
reward_function: RewardFunction
class ConfigSchema(BaseModel):
"""
Configuration Schema for AbstractAgents.
@@ -85,14 +84,14 @@ class AbstractAgent(BaseModel):
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
This is because the calculation for the next execution time is now + (frequency +- variance).
If variance were greater than frequency, sometimes the bracketed term would be negative
and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
@@ -101,14 +100,12 @@ class AbstractAgent(BaseModel):
)
return self
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
if identifier in cls._registry:
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
super().__init_subclass__(**kwargs)
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
@@ -117,7 +114,12 @@ class AbstractAgent(BaseModel):
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
obj = cls(config=cls.ConfigSchema(**config))
obj = cls(
config=cls.ConfigSchema(**config["agent_settings"]),
action_manager=ActionManager.from_config(**config["action_manager"]),
observation_manager=ObservationManager.from_config(**config["observation_space"]),
reward_function=RewardFunction.from_config(**config["reward_function"]),
)
return obj
def update_observation(self, state: Dict) -> ObsType:
@@ -206,9 +208,8 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
"""Configuration Schema for Proxy Agent."""
agent_name: str = "Proxy_Agent"
agent_settings: AgentSettings = None
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
action_masking: bool = agent_settings.action_masking if agent_settings else False
flatten_obs: bool = False
action_masking: bool = False
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
@@ -221,7 +222,7 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""
return self.config.action_manager.get_action(self.most_recent_action)
return self.action_manager.get_action(self.most_recent_action)
def store_action(self, action: ActType):
"""

View File

@@ -7,10 +7,8 @@ import numpy as np
from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.rewards import SharedReward
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.creation import NetworkNodeAdder
@@ -532,24 +530,14 @@ class PrimaiteGame:
action_space_cfg = agent_cfg["action_space"]
observation_space_cfg = agent_cfg["observation_space"]
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationManager.from_config(observation_space_cfg)
# CREATE ACTION SPACE
action_space = ActionManager.from_config(game, action_space_cfg)
# CREATE REWARD FUNCTION
reward_function = RewardFunction.from_config(reward_function_cfg)
agent_settings = agent_cfg["agent_settings"]
# CREATE AGENT
agent_settings = agent_cfg["agent_settings"]
agent_config = {
"agent_name": agent_ref,
"action_manager": action_space,
"observation_manager": obs_space,
"reward_function": reward_function,
"action_manager": action_space_cfg,
"observation_manager": observation_space_cfg,
"reward_function": reward_function_cfg,
"agent_settings": agent_settings,
}