#2869 - Update Config for some agent classes to use pydantic.Field, amend some identifiers and agent_name variables
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from ray import init as rayinit
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
@@ -265,16 +266,16 @@ def example_network() -> Network:
|
||||
return network
|
||||
|
||||
|
||||
class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"):
|
||||
class ControlledAgent(AbstractAgent, identifier="ControlledAgent"):
|
||||
"""Agent that can be controlled by the tests."""
|
||||
|
||||
config: "ControlledAgent.ConfigSchema"
|
||||
config: "ControlledAgent.ConfigSchema" = Field(default_factory=lambda: ControlledAgent.ConfigSchema())
|
||||
most_recent_action: Optional[Tuple[str, Dict]] = None
|
||||
|
||||
class ConfigSchema(AbstractAgent.ConfigSchema):
|
||||
"""Configuration Schema for Abstract Agent used in tests."""
|
||||
|
||||
agent_name: str = "Controlled_Agent"
|
||||
type: str = "ControlledAgent"
|
||||
|
||||
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Return the agent's most recent action, formatted in CAOS format."""
|
||||
|
||||
Reference in New Issue
Block a user