From 505eab6ed91f520d672775c18021706bd8c3a578 Mon Sep 17 00:00:00 2001
From: Charlie Crane
Date: Fri, 3 Jan 2025 14:02:36 +0000
Subject: [PATCH] #2869 - Changes following review discussion

---
Review notes (text between the "---" line and the diffstat is ignored by
`git am`):

* Layout restored: the patch had been collapsed onto three lines. Line
  breaks are rebuilt from the hunk headers' line counts; indentation is
  inferred from the Python syntax and should be checked against the
  target files before applying.
* Fix: AbstractAgent.from_config now reads config["observation_manager"]
  rather than config["observation_space"], so that it matches the key
  built into agent_config in game.py. The other three keys
  (agent_settings, action_manager, reward_function) already agreed.
  The post-image hash on the interface.py index line predates this fix.
* NOTE(review): the calls removed from game.py passed a positional dict
  (ObservationManager.from_config(cfg), RewardFunction.from_config(cfg))
  and ActionManager.from_config also took `game`. The new
  from_config(**config[...]) calls presumably rely on those signatures
  having changed elsewhere - confirm before merging.

 src/primaite/game/agent/interface.py | 25 +++++++++++++------------
 src/primaite/game/game.py            | 22 +++++-----------------
 2 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py
index 14416241..b980d748 100644
--- a/src/primaite/game/agent/interface.py
+++ b/src/primaite/game/agent/interface.py
@@ -49,13 +49,12 @@ class AbstractAgent(BaseModel):
     _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
     _logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
 
-    config: "AbstractAgent.ConfigSchema"
     history: List[AgentHistoryItem] = []
+    config: "AbstractAgent.ConfigSchema"
     action_manager: ActionManager
     observation_manager: ObservationManager
     reward_function: RewardFunction
 
-
     class ConfigSchema(BaseModel):
         """
         Configuration Schema for AbstractAgents.
@@ -85,14 +84,14 @@ class AbstractAgent(BaseModel):
         variance: int = 0
         "The amount the frequency can randomly change to"
 
-
         @model_validator(mode="after")
         def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
             """
             Make sure variance is equal to or lower than frequency.
 
-            This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
-            greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
+            This is because the calculation for the next execution time is now + (frequency +- variance).
+            If variance were greater than frequency, sometimes the bracketed term would be negative
+            and the attack would never happen again.
             """
             if self.variance > self.frequency:
                 raise ValueError(
@@ -101,14 +100,12 @@ class AbstractAgent(BaseModel):
                 )
             return self
 
-
     def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
         if identifier in cls._registry:
             raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
         cls._registry[identifier] = cls
         super().__init_subclass__(**kwargs)
 
-
     @property
     def flatten_obs(self) -> bool:
         """Return agent flatten_obs param."""
@@ -117,7 +114,12 @@ class AbstractAgent(BaseModel):
     @classmethod
     def from_config(cls, config: Dict) -> "AbstractAgent":
         """Creates an agent component from a configuration dictionary."""
-        obj = cls(config=cls.ConfigSchema(**config))
+        obj = cls(
+            config=cls.ConfigSchema(**config["agent_settings"]),
+            action_manager=ActionManager.from_config(**config["action_manager"]),
+            observation_manager=ObservationManager.from_config(**config["observation_manager"]),
+            reward_function=RewardFunction.from_config(**config["reward_function"]),
+        )
         return obj
 
     def update_observation(self, state: Dict) -> ObsType:
@@ -206,9 +208,8 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
         """Configuration Schema for Proxy Agent."""
 
         agent_name: str = "Proxy_Agent"
-        agent_settings: AgentSettings = None
-        flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
-        action_masking: bool = agent_settings.action_masking if agent_settings else False
+        flatten_obs: bool = False
+        action_masking: bool = False
 
     def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
         """
@@ -221,7 +222,7 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
         :return: Action to be taken in CAOS format.
         :rtype: Tuple[str, Dict]
         """
-        return self.config.action_manager.get_action(self.most_recent_action)
+        return self.action_manager.get_action(self.most_recent_action)
 
     def store_action(self, action: ActType):
         """
diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py
index 781db2c5..e83f59a6 100644
--- a/src/primaite/game/game.py
+++ b/src/primaite/game/game.py
@@ -7,10 +7,8 @@ import numpy as np
 from pydantic import BaseModel, ConfigDict
 
 from primaite import DEFAULT_BANDWIDTH, getLogger
-from primaite.game.agent.actions import ActionManager
-from primaite.game.agent.observations.observation_manager import ObservationManager
-from primaite.game.agent.rewards import RewardFunction, SharedReward
 from primaite.game.agent.interface import AbstractAgent, ProxyAgent
+from primaite.game.agent.rewards import SharedReward
 from primaite.game.science import graph_has_cycle, topological_sort
 from primaite.simulator import SIM_OUTPUT
 from primaite.simulator.network.creation import NetworkNodeAdder
@@ -532,24 +530,14 @@ class PrimaiteGame:
             action_space_cfg = agent_cfg["action_space"]
             observation_space_cfg = agent_cfg["observation_space"]
             reward_function_cfg = agent_cfg["reward_function"]
-
-            # CREATE OBSERVATION SPACE
-            obs_space = ObservationManager.from_config(observation_space_cfg)
-
-            # CREATE ACTION SPACE
-            action_space = ActionManager.from_config(game, action_space_cfg)
-
-            # CREATE REWARD FUNCTION
-            reward_function = RewardFunction.from_config(reward_function_cfg)
+            agent_settings = agent_cfg["agent_settings"]
 
             # CREATE AGENT
-
-            agent_settings = agent_cfg["agent_settings"]
             agent_config = {
                 "agent_name": agent_ref,
-                "action_manager": action_space,
-                "observation_manager": obs_space,
-                "reward_function": reward_function,
+                "action_manager": action_space_cfg,
+                "observation_manager": observation_space_cfg,
+                "reward_function": reward_function_cfg,
                 "agent_settings": agent_settings,
             }
 