#2869 - Give the `config` field of several agent classes a `pydantic.Field(default_factory=...)` default, rename agent `identifier` values to CamelCase, and replace `agent_name` ConfigSchema defaults with `type` fields

This commit is contained in:
Charlie Crane
2025-01-13 10:51:30 +00:00
parent 511abea59c
commit 32fc970cfe
3 changed files with 13 additions and 14 deletions

View File

@@ -6,7 +6,7 @@ from abc import abstractmethod
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.agent_log import AgentLog
@@ -50,7 +50,7 @@ class AbstractAgent(BaseModel):
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = []
config: "AbstractAgent.ConfigSchema"
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
action_manager: "ActionManager"
observation_manager: "ObservationManager"
reward_function: "RewardFunction"
@@ -62,8 +62,6 @@ class AbstractAgent(BaseModel):
:param type: Type of agent being generated.
:type type: str
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
:type agent_name: str
:param observation_space: Observation space for the agent.
:type observation_space: Optional[ObservationSpace]
:param reward_function: Reward function for the agent.
@@ -73,7 +71,7 @@ class AbstractAgent(BaseModel):
"""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
agent_name: str = "Abstract_Agent"
type: str = "AbstractAgent"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
@@ -185,15 +183,15 @@ class AbstractAgent(BaseModel):
self.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
"""Base class for actors which generate their own behaviour."""
config: "AbstractScriptedAgent.ConfigSchema"
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for AbstractScriptedAgents."""
agent_name: str = "Abstract_Scripted_Agent"
type: str = "AbstractScriptedAgent"
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -204,13 +202,13 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent")
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
"""Agent that sends observations to an RL model and receives actions from that model."""
config: "ProxyAgent.ConfigSchema"
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
most_recent_action: ActType = None
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Proxy Agent."""
agent_name: str = "Proxy_Agent"
type: str = "Proxy_Agent"
flatten_obs: bool = False
action_masking: bool = False

View File

@@ -533,7 +533,7 @@ class PrimaiteGame:
agent_settings = agent_cfg["agent_settings"]
agent_config = {
"agent_name": agent_name,
"type": agent_type,
"action_manager": action_space_cfg,
"observation_manager": observation_space_cfg,
"reward_function": reward_function_cfg,

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple
import pytest
import yaml
from pydantic import Field
from ray import init as rayinit
from primaite import getLogger, PRIMAITE_PATHS
@@ -265,16 +266,16 @@ def example_network() -> Network:
return network
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
class ControlledAgent(AbstractAgent, identifier="ControlledAgent"):
"""Agent that can be controlled by the tests."""
config: "ControlledAgent.ConfigSchema"
config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema())
most_recent_action: Optional[Tuple[str, Dict]] = None
class ConfigSchema(AbstractAgent.ConfigSchema):
"""Configuration Schema for Abstract Agent used in tests."""
agent_name: str = "Controlled_Agent"
type: str = "ControlledAgent"
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""