#2869 - Changes following review discussion
This commit is contained in:
@@ -49,13 +49,12 @@ class AbstractAgent(BaseModel):
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
_logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
history: List[AgentHistoryItem] = []
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
action_manager: ActionManager
|
||||
observation_manager: ObservationManager
|
||||
reward_function: RewardFunction
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
"""
|
||||
Configuration Schema for AbstractAgents.
|
||||
@@ -85,14 +84,14 @@ class AbstractAgent(BaseModel):
|
||||
variance: int = 0
|
||||
"The amount the frequency can randomly change to"
|
||||
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_variance_lt_frequency(self) -> "AbstractAgent.ConfigSchema":
|
||||
"""
|
||||
Make sure variance is equal to or lower than frequency.
|
||||
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
|
||||
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
|
||||
This is because the calculation for the next execution time is now + (frequency +- variance).
|
||||
If variance were greater than frequency, sometimes the bracketed term would be negative
|
||||
and the attack would never happen again.
|
||||
"""
|
||||
if self.variance > self.frequency:
|
||||
raise ValueError(
|
||||
@@ -101,14 +100,12 @@ class AbstractAgent(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
@property
|
||||
def flatten_obs(self) -> bool:
|
||||
"""Return agent flatten_obs param."""
|
||||
@@ -117,7 +114,12 @@ class AbstractAgent(BaseModel):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractAgent":
|
||||
"""Creates an agent component from a configuration dictionary."""
|
||||
obj = cls(config=cls.ConfigSchema(**config))
|
||||
obj = cls(
|
||||
config=cls.ConfigSchema(**config["agent_settings"]),
|
||||
action_manager=ActionManager.from_config(**config["action_manager"]),
|
||||
observation_manager=ObservationManager.from_config(**config["observation_space"]),
|
||||
reward_function=RewardFunction.from_config(**config["reward_function"]),
|
||||
)
|
||||
return obj
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
@@ -206,9 +208,8 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
agent_name: str = "Proxy_Agent"
|
||||
agent_settings: AgentSettings = None
|
||||
flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
|
||||
action_masking: bool = agent_settings.action_masking if agent_settings else False
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
@@ -221,7 +222,7 @@ class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.config.action_manager.get_action(self.most_recent_action)
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
|
||||
@@ -7,10 +7,8 @@ import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.rewards import SharedReward
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.network.creation import NetworkNodeAdder
|
||||
@@ -532,24 +530,14 @@ class PrimaiteGame:
|
||||
action_space_cfg = agent_cfg["action_space"]
|
||||
observation_space_cfg = agent_cfg["observation_space"]
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space = ActionManager.from_config(game, action_space_cfg)
|
||||
|
||||
# CREATE REWARD FUNCTION
|
||||
reward_function = RewardFunction.from_config(reward_function_cfg)
|
||||
agent_settings = agent_cfg["agent_settings"]
|
||||
|
||||
# CREATE AGENT
|
||||
|
||||
agent_settings = agent_cfg["agent_settings"]
|
||||
agent_config = {
|
||||
"agent_name": agent_ref,
|
||||
"action_manager": action_space,
|
||||
"observation_manager": obs_space,
|
||||
"reward_function": reward_function,
|
||||
"action_manager": action_space_cfg,
|
||||
"observation_manager": observation_space_cfg,
|
||||
"reward_function": reward_function_cfg,
|
||||
"agent_settings": agent_settings,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user