#2869 - Make agent schema children work properly

This commit is contained in:
Marek Wolan
2025-01-15 16:44:17 +00:00
parent b4b6c16872
commit f8fb052dad
5 changed files with 144 additions and 38 deletions

View File

@@ -39,6 +39,13 @@ class DoNothingAction(AbstractAction, identifier="do_nothing"):
return ["do_nothing"]
class _ActionMapItem(BaseModel):
model_config = ConfigDict(extra="forbid")
action: str
options: Dict
class ActionManager(BaseModel):
"""Class which manages the action space for an agent."""
@@ -46,20 +53,23 @@ class ActionManager(BaseModel):
"""Config Schema for ActionManager."""
model_config = ConfigDict(extra="forbid")
action_map: Dict[int, Tuple[str, Dict]] = {}
action_map: Dict[int, _ActionMapItem] = {}
"""Mapping between integer action choices and CAOS actions."""
@field_validator("action_map", mode="after")
def consecutive_action_nums(cls, v: Dict) -> Dict:
"""Make sure all numbers between 0 and N are represented as dict keys in action map."""
assert all([i in v.keys() for i in range(len(v))])
return v
config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
@property
def action_map(self) -> Dict[int, Tuple[str, Dict]]:
"""Convenience method for accessing the action map."""
return self.config.action_map
action_map: Dict[int, Tuple[str, Dict]] = {}
"""Init as empty, populate after model validation."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()}
def get_action(self, action: int) -> Tuple[str, Dict]:
"""

View File

@@ -52,7 +52,7 @@ class AbstractAgent(BaseModel, ABC):
"""Configuration Schema for AbstractAgents."""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
type: str = "AbstractAgent"
type: str
action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
observation_space: ObservationManager.ConfigSchema = Field(
default_factory=lambda: ObservationManager.ConfigSchema()
@@ -63,9 +63,10 @@ class AbstractAgent(BaseModel, ABC):
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = []
action_manager: "ActionManager"
observation_manager: "ObservationManager"
reward_function: "RewardFunction"
action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
_registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
@@ -77,32 +78,18 @@ class AbstractAgent(BaseModel, ABC):
raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
cls._registry[identifier] = cls
def __init__(self, config: ConfigSchema, **kwargs):
kwargs["action_manager"] = kwargs.get("action_manager") or ActionManager.from_config(config.action_space)
kwargs["observation_manager"] = kwargs.get("observation_manager") or ObservationManager(
config.observation_space
)
kwargs["reward_function"] = kwargs.get("reward_function") or RewardFunction.from_config(config.reward_function)
super().__init__(config=config, **kwargs)
def model_post_init(self, __context: Any) -> None:
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
self.action_manager = ActionManager(config=self.config.action_space)
self.observation_manager = ObservationManager(config=self.config.observation_space)
self.reward_function = RewardFunction(config=self.config.reward_function)
return super().model_post_init(__context)
@property
def flatten_obs(self) -> bool:
"""Return agent flatten_obs param."""
return self.config.flatten_obs
@classmethod
def from_config(cls, config: Dict) -> "AbstractAgent":
"""Creates an agent component from a configuration dictionary."""
if config["type"] not in cls._registry:
return ValueError(f"Invalid Agent Type: {config['type']}")
obj = cls(
config=cls.ConfigSchema(**config["agent_settings"]),
action_manager=ActionManager.from_config(config["action_space"]),
observation_manager=ObservationManager(config["observation_space"]),
reward_function=RewardFunction.from_config(config["reward_function"]),
)
return obj
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.

View File

@@ -194,7 +194,7 @@ class ObservationManager(BaseModel):
if "options" not in data:
data["options"] = obs_class.ConfigSchema()
# if options passed as a dict, convert to a schema
# if options passed as a dict, validate against schema
elif isinstance(data["options"], dict):
data["options"] = obs_class.ConfigSchema(**data["options"])

View File

@@ -30,7 +30,7 @@ the structure:
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Never
from primaite import getLogger
@@ -410,15 +410,74 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
return self.config.action_penalty
class RewardFunction:
class _SingleComponentConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
type: str
options: AbstractReward.ConfigSchema
weight: float = 1.0
@model_validator(mode="before")
@classmethod
def resolve_obs_options_type(cls, data: Any) -> Any:
"""
When constructing the model from a dict, resolve the correct reward class based on `type` field.
Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
a way to do this properly using discriminated union, but most advice on the internet assumes that the full
list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
our plugin architecture.
We may be able to revisit and implement a better solution when needed using the following resources as
research starting points:
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
https://github.com/pydantic/pydantic/issues/7366
https://github.com/pydantic/pydantic/issues/7462
https://github.com/pydantic/pydantic/pull/7983
"""
if not isinstance(data, dict):
return data
assert "type" in data, ValueError('Reward component definition is missing the "type" key.')
rew_type = data["type"]
rew_class = AbstractReward._registry[rew_type]
# if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
if "options" not in data:
data["options"] = rew_class.ConfigSchema()
# if options are passed as a dict, validate against schema
elif isinstance(data["options"], dict):
data["options"] = rew_class.ConfigSchema(**data["options"])
return data
class RewardFunction(BaseModel):
"""Manages the reward function for the agent."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float = 0.0
self.total_reward: float = 0.0
model_config = ConfigDict(extra="forbid")
class ConfigSchema(BaseModel):
"""Config Schema for RewardFunction."""
model_config = ConfigDict(extra="forbid")
reward_components: Iterable[_SingleComponentConfig] = []
config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
reward_components: List[Tuple[AbstractReward, float]] = []
current_reward: float = 0.0
total_reward: float = 0.0
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
for rew_config in self.config.reward_components:
rew_class = AbstractReward._registry[rew_config.type]
rew_instance = rew_class(config=rew_config.options)
self.register_component(component=rew_instance, weight=rew_config.weight)
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.

View File

@@ -0,0 +1,50 @@
from primaite.game.agent.observations.file_system_observations import FileObservation
from primaite.game.agent.observations.observation_manager import NullObservation
from primaite.game.agent.scripted_agents.random_agent import RandomAgent
def test_creating_empty_agent():
agent = RandomAgent()
assert len(agent.action_manager.action_map) == 0
assert isinstance(agent.observation_manager.obs, NullObservation)
assert len(agent.reward_function.reward_components) == 0
def test_creating_agent_from_dict():
action_config = {
"action_map": {
0: {"action": "do_nothing", "options": {}},
1: {
"action": "node_application_execute",
"options": {"node_name": "client", "application_name": "database"},
},
}
}
observation_config = {
"type": "FILE",
"options": {
"file_name": "dog.pdf",
"include_num_access": False,
"file_system_requires_scan": False,
},
}
reward_config = {
"reward_components": [
{
"type": "DATABASE_FILE_INTEGRITY",
"weight": 0.3,
"options": {"node_hostname": "server", "folder_name": "database", "file_name": "database.db"},
}
]
}
agent = RandomAgent(
config={
"action_space": action_config,
"observation_space": observation_config,
"reward_function": reward_config,
}
)
assert len(agent.action_manager.action_map) == 2
assert isinstance(agent.observation_manager.obs, FileObservation)
assert len(agent.reward_function.reward_components) == 1