#2869 - Make agent schema children work properly
This commit is contained in:
@@ -39,6 +39,13 @@ class DoNothingAction(AbstractAction, identifier="do_nothing"):
|
||||
return ["do_nothing"]
|
||||
|
||||
|
||||
class _ActionMapItem(BaseModel):
    """Schema for one action-map entry: an action identifier and the options paired with it."""

    # Keys other than the fields declared below are rejected.
    model_config = ConfigDict(extra="forbid")

    action: str  # action identifier, e.g. "do_nothing"
    options: Dict  # options passed along with the action
|
||||
|
||||
|
||||
class ActionManager(BaseModel):
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -46,20 +53,23 @@ class ActionManager(BaseModel):
|
||||
"""Config Schema for ActionManager."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
action_map: Dict[int, _ActionMapItem] = {}
|
||||
"""Mapping between integer action choices and CAOS actions."""
|
||||
|
||||
@field_validator("action_map", mode="after")
@classmethod
def consecutive_action_nums(cls, v: Dict) -> Dict:
    """Make sure all numbers between 0 and N are represented as dict keys in action map.

    :param v: The action map after field validation.
    :return: The same action map, unchanged.
    :raises ValueError: If any integer in ``range(len(v))`` is not a key of ``v``.
    """
    # Raise explicitly rather than ``assert``: assertions are removed when Python runs with ``-O``,
    # which would silently disable this check.
    missing = [i for i in range(len(v)) if i not in v]
    if missing:
        raise ValueError(f"Action map keys must cover every integer from 0 to {len(v) - 1}; missing: {missing}")
    return v
|
||||
|
||||
config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
|
||||
@property
def action_map(self) -> Dict[int, Tuple[str, Dict]]:
    """Shortcut to the action map stored on this manager's config."""
    cfg = self.config
    return cfg.action_map
|
||||
action_map: Dict[int, Tuple[str, Dict]] = {}
|
||||
"""Init as empty, populate after model validation."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Initialise the model, then populate ``action_map`` with ``(action, options)`` tuples from the config."""
    super().__init__(**kwargs)
    flattened = {}
    for index, item in self.config.action_map.items():
        flattened[index] = (item.action, item.options)
    self.action_map = flattened
|
||||
|
||||
def get_action(self, action: int) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -52,7 +52,7 @@ class AbstractAgent(BaseModel, ABC):
|
||||
"""Configuration Schema for AbstractAgents."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
type: str = "AbstractAgent"
|
||||
type: str
|
||||
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
|
||||
observation_space: ObservationManager.ConfigSchema = Field(
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
@@ -63,9 +63,10 @@ class AbstractAgent(BaseModel, ABC):
|
||||
|
||||
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
|
||||
history: List[AgentHistoryItem] = []
|
||||
action_manager: "ActionManager"
|
||||
observation_manager: "ObservationManager"
|
||||
reward_function: "RewardFunction"
|
||||
|
||||
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
|
||||
observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
|
||||
reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
|
||||
|
||||
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
|
||||
|
||||
@@ -77,32 +78,18 @@ class AbstractAgent(BaseModel, ABC):
|
||||
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
def __init__(self, config: ConfigSchema, **kwargs):
    """Create the agent, building from ``config`` any manager the caller did not supply."""
    # A keyword argument that is absent or falsy is replaced by an object built from the config.
    if not kwargs.get("action_manager"):
        kwargs["action_manager"] = ActionManager.from_config(config.action_space)
    if not kwargs.get("observation_manager"):
        kwargs["observation_manager"] = ObservationManager(config.observation_space)
    if not kwargs.get("reward_function"):
        kwargs["reward_function"] = RewardFunction.from_config(config.reward_function)
    super().__init__(config=config, **kwargs)
|
||||
def model_post_init(self, __context: Any) -> None:
    """Replace the default action manager, observation manager and reward function with ones built from the config."""
    cfg = self.config
    self.action_manager = ActionManager(config=cfg.action_space)
    self.observation_manager = ObservationManager(config=cfg.observation_space)
    self.reward_function = RewardFunction(config=cfg.reward_function)
    return super().model_post_init(__context)
|
||||
|
||||
@property
def flatten_obs(self) -> bool:
    """Expose the ``flatten_obs`` setting held on the agent config."""
    settings = self.config
    return settings.flatten_obs
|
||||
|
||||
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
    """Creates an agent component from a configuration dictionary.

    :param config: Dictionary read for the keys ``type``, ``agent_settings``, ``action_space``,
        ``observation_space`` and ``reward_function``.
    :return: The constructed agent.
    :raises ValueError: If ``config["type"]`` is not present in the agent registry.
    """
    agent_type = config["type"]
    if agent_type not in cls._registry:
        # Raise the error; returning it would hand the caller an exception object in place of an agent.
        raise ValueError(f"Invalid Agent Type: {agent_type}")
    return cls(
        config=cls.ConfigSchema(**config["agent_settings"]),
        action_manager=ActionManager.from_config(config["action_space"]),
        observation_manager=ObservationManager(config["observation_space"]),
        reward_function=RewardFunction.from_config(config["reward_function"]),
    )
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
@@ -194,7 +194,7 @@ class ObservationManager(BaseModel):
|
||||
if "options" not in data:
|
||||
data["options"] = obs_class.ConfigSchema()
|
||||
|
||||
# if options passed as a dict, convert to a schema
|
||||
# if options passed as a dict, validate against schema
|
||||
elif isinstance(data["options"], dict):
|
||||
data["options"] = obs_class.ConfigSchema(**data["options"])
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ the structure:
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -410,15 +410,74 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
return self.config.action_penalty
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
class _SingleComponentConfig(BaseModel):
    """Schema for a single reward component: its type, its options and its weight."""

    model_config = ConfigDict(extra="forbid")

    type: str
    options: AbstractReward.ConfigSchema
    weight: float = 1.0

    @model_validator(mode="before")
    @classmethod
    def resolve_obs_options_type(cls, data: Any) -> Any:
        """
        When constructing the model from a dict, resolve the correct reward class based on `type` field.

        Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
        passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
        a way to do this properly using discriminated union, but most advice on the internet assumes that the full
        list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
        our plugin architecture.

        We may be able to revisit and implement a better solution when needed using the following resources as
        research starting points:
        https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
        https://github.com/pydantic/pydantic/issues/7366
        https://github.com/pydantic/pydantic/issues/7462
        https://github.com/pydantic/pydantic/pull/7983

        :param data: Raw input to the model; anything that is not a dict is returned untouched.
        :return: The input, with ``options`` replaced by an instance of the resolved reward class's schema
            when it was missing or given as a dict.
        :raises ValueError: If ``data`` is a dict without a ``type`` key.
        """
        if not isinstance(data, dict):
            return data

        # Raise explicitly rather than ``assert``: assertions are removed under ``python -O``, and the
        # registry lookup below would then fail on the missing key instead.
        if "type" not in data:
            raise ValueError('Reward component definition is missing the "type" key.')
        rew_type = data["type"]
        # The lookup fails if ``rew_type`` has not been registered on AbstractReward.
        rew_class = AbstractReward._registry[rew_type]

        # if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
        if "options" not in data:
            data["options"] = rew_class.ConfigSchema()

        # if options are passed as a dict, validate against schema
        elif isinstance(data["options"], dict):
            data["options"] = rew_class.ConfigSchema(**data["options"])

        return data
|
||||
|
||||
|
||||
class RewardFunction(BaseModel):
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
def __init__(self):
    """Set up a reward function with no components and zeroed rewards."""
    # reward_components keeps track of reward components and the weights assigned to each.
    self.reward_components: List[Tuple[AbstractReward, float]] = []
    self.current_reward: float = 0.0
    self.total_reward: float = 0.0
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class ConfigSchema(BaseModel):
    """Config Schema for RewardFunction."""

    model_config = ConfigDict(extra="forbid")

    # Declared as ``List`` rather than ``Iterable``: pydantic v2 validates ``Iterable`` fields lazily and
    # stores a one-shot iterator, so entries would only be validated while being iterated and the field
    # could not be read a second time. ``List`` validates every entry at construction and stays re-iterable.
    reward_components: List[_SingleComponentConfig] = []
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
|
||||
reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
|
||||
current_reward: float = 0.0
|
||||
total_reward: float = 0.0
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Initialise the model, then register one reward component for each entry in the config."""
    super().__init__(**kwargs)
    for entry in self.config.reward_components:
        # Look up the reward class by its type string and build it from the entry's options.
        component = AbstractReward._registry[entry.type](config=entry.options)
        self.register_component(component=component, weight=entry.weight)
|
||||
|
||||
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
|
||||
50
tests/unit_tests/_primaite/_game/_agent/test_agent.py
Normal file
50
tests/unit_tests/_primaite/_game/_agent/test_agent.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation
|
||||
from primaite.game.agent.observations.observation_manager import NullObservation
|
||||
from primaite.game.agent.scripted_agents.random_agent import RandomAgent
|
||||
|
||||
|
||||
def test_creating_empty_agent():
    """An agent created with no arguments has an empty action map, a null observation and no reward components."""
    empty_agent = RandomAgent()

    action_map = empty_agent.action_manager.action_map
    assert len(action_map) == 0

    observation = empty_agent.observation_manager.obs
    assert isinstance(observation, NullObservation)

    components = empty_agent.reward_function.reward_components
    assert len(components) == 0
|
||||
|
||||
|
||||
def test_creating_agent_from_dict():
    """An agent created from nested config dicts builds its action map, observation and reward components."""
    execute_options = {"node_name": "client", "application_name": "database"}
    integrity_options = {"node_hostname": "server", "folder_name": "database", "file_name": "database.db"}

    agent_config = {
        "action_space": {
            "action_map": {
                0: {"action": "do_nothing", "options": {}},
                1: {"action": "node_application_execute", "options": execute_options},
            }
        },
        "observation_space": {
            "type": "FILE",
            "options": {
                "file_name": "dog.pdf",
                "include_num_access": False,
                "file_system_requires_scan": False,
            },
        },
        "reward_function": {
            "reward_components": [
                {"type": "DATABASE_FILE_INTEGRITY", "weight": 0.3, "options": integrity_options},
            ]
        },
    }

    agent = RandomAgent(config=agent_config)

    assert len(agent.action_manager.action_map) == 2
    assert isinstance(agent.observation_manager.obs, FileObservation)
    assert len(agent.reward_function.reward_components) == 1
|
||||
Reference in New Issue
Block a user