diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 370e6bbb..26445830 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,7 +6,7 @@ from abc import abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from primaite.game.agent.actions import ActionManager from primaite.game.agent.agent_log import AgentLog @@ -50,7 +50,7 @@ class AbstractAgent(BaseModel): logger: AgentLog = AgentLog(agent_name="Abstract_Agent") history: List[AgentHistoryItem] = [] - config: "AbstractAgent.ConfigSchema" + config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) action_manager: "ActionManager" observation_manager: "ObservationManager" reward_function: "RewardFunction" @@ -62,8 +62,6 @@ class AbstractAgent(BaseModel): :param type: Type of agent being generated. :type type: str - :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. - :type agent_name: str :param observation_space: Observation space for the agent. :type observation_space: Optional[ObservationSpace] :param reward_function: Reward function for the agent. @@ -73,7 +71,7 @@ class AbstractAgent(BaseModel): """ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - agent_name: str = "Abstract_Agent" + type: str = "AbstractAgent" flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." 
action_masking: bool = False @@ -185,15 +183,15 @@ class AbstractAgent(BaseModel): self.history[-1].reward = self.reward_function.current_reward -class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"): +class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"): """Base class for actors which generate their own behaviour.""" - config: "AbstractScriptedAgent.ConfigSchema" + config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema()) class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for AbstractScriptedAgents.""" - agent_name: str = "Abstract_Scripted_Agent" + type: str = "AbstractScriptedAgent" @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -204,13 +202,13 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent") class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): """Agent that sends observations to an RL model and receives actions from that model.""" - config: "ProxyAgent.ConfigSchema" + config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema()) most_recent_action: ActType = None class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Proxy Agent.""" - agent_name: str = "Proxy_Agent" + type: str = "ProxyAgent" flatten_obs: bool = False action_masking: bool = False diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 69e294ae..f2b1de4c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -533,7 +533,7 @@ class PrimaiteGame: agent_settings = agent_cfg["agent_settings"] agent_config = { - "agent_name": agent_name, + "type": agent_type, "action_manager": action_space_cfg, "observation_manager": observation_space_cfg, "reward_function": reward_function_cfg, diff --git a/tests/conftest.py b/tests/conftest.py index 9d18a18b..b4b72e55 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple import pytest import yaml +from pydantic import Field from ray import init as rayinit from primaite import getLogger, PRIMAITE_PATHS @@ -265,16 +266,16 @@ def example_network() -> Network: return network -class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"): +class ControlledAgent(AbstractAgent, identifier="ControlledAgent"): """Agent that can be controlled by the tests.""" - config: "ControlledAgent.ConfigSchema" + config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema()) most_recent_action: Optional[Tuple[str, Dict]] = None class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Abstract Agent used in tests.""" - agent_name: str = "Controlled_Agent" + type: str = "ControlledAgent" def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format."""