#2869 - Update the `config` attribute of some agent classes to use pydantic.Field (with a default_factory); amend some identifiers and agent_name variables
This commit is contained in:
@@ -6,7 +6,7 @@ from abc import abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.agent_log import AgentLog
|
||||
@@ -50,7 +50,7 @@ class AbstractAgent(BaseModel):
|
||||
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
|
||||
history: List[AgentHistoryItem] = []
|
||||
config: "AbstractAgent.ConfigSchema"
|
||||
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
action_manager: "ActionManager"
|
||||
observation_manager: "ObservationManager"
|
||||
reward_function: "RewardFunction"
|
||||
@@ -62,8 +62,6 @@ class AbstractAgent(BaseModel):
|
||||
|
||||
:param type: Type of agent being generated.
|
||||
:type type: str
|
||||
:param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
|
||||
:type agent_name: str
|
||||
:param observation_space: Observation space for the agent.
|
||||
:type observation_space: Optional[ObservationSpace]
|
||||
:param reward_function: Reward function for the agent.
|
||||
@@ -73,7 +71,7 @@ class AbstractAgent(BaseModel):
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
agent_name: str = "Abstract_Agent"
|
||||
type: str = "AbstractAgent"
|
||||
flatten_obs: bool = True
|
||||
"Whether to flatten the observation space before passing it to the agent. True by default."
|
||||
action_masking: bool = False
|
||||
@@ -185,15 +183,15 @@ class AbstractAgent(BaseModel):
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"):
|
||||
class AbstractScriptedAgent(AbstractAgent, identifier="AbstractScriptedAgent"):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
config: "AbstractScriptedAgent.ConfigSchema"
|
||||
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for AbstractScriptedAgents."""
|
||||
|
||||
agent_name: str = "Abstract_Scripted_Agent"
|
||||
type: str = "AbstractScriptedAgent"
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -204,13 +202,13 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent")
|
||||
class ProxyAgent(AbstractAgent, identifier="ProxyAgent"):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
config: "ProxyAgent.ConfigSchema"
|
||||
config: "ProxyAgent.ConfigSchema" = Field(default_factory=lambda: ProxyAgent.ConfigSchema())
|
||||
most_recent_action: ActType = None
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Proxy Agent."""
|
||||
|
||||
agent_name: str = "Proxy_Agent"
|
||||
type: str = "Proxy_Agent"
|
||||
flatten_obs: bool = False
|
||||
action_masking: bool = False
|
||||
|
||||
|
||||
@@ -533,7 +533,7 @@ class PrimaiteGame:
|
||||
agent_settings = agent_cfg["agent_settings"]
|
||||
|
||||
agent_config = {
|
||||
"agent_name": agent_name,
|
||||
"type": agent_type,
|
||||
"action_manager": action_space_cfg,
|
||||
"observation_manager": observation_space_cfg,
|
||||
"reward_function": reward_function_cfg,
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from ray import init as rayinit
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
@@ -265,16 +266,16 @@ def example_network() -> Network:
|
||||
return network
|
||||
|
||||
|
||||
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
|
||||
class ControlledAgent(AbstractAgent, identifier="ControlledAgent"):
|
||||
"""Agent that can be controlled by the tests."""
|
||||
|
||||
config: "ControlledAgent.ConfigSchema"
|
||||
config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema())
|
||||
most_recent_action: Optional[Tuple[str, Dict]] = None
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Abstract Agent used in tests."""
|
||||
|
||||
agent_name: str = "Controlled_Agent"
|
||||
type: str = "ControlledAgent"
|
||||
|
||||
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return the agent's most recent action, formatted in CAOS format."""
|
||||
|
||||
Reference in New Issue
Block a user