From f8fb052dadfa3f6a4d43376c281e91f587b8c30c Mon Sep 17 00:00:00 2001
From: Marek Wolan
Date: Wed, 15 Jan 2025 16:44:17 +0000
Subject: [PATCH] #2869 - Make agent schema children work properly

---
 src/primaite/game/agent/actions/manager.py    | 20 +++--
 src/primaite/game/agent/interface.py          | 35 +++------
 .../agent/observations/observation_manager.py |  2 +-
 src/primaite/game/agent/rewards.py            | 75 +++++++++++++++++--
 .../_primaite/_game/_agent/test_agent.py      | 50 +++++++++++++
 5 files changed, 144 insertions(+), 38 deletions(-)
 create mode 100644 tests/unit_tests/_primaite/_game/_agent/test_agent.py

diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py
index 0f7db2f3..a6e235c5 100644
--- a/src/primaite/game/agent/actions/manager.py
+++ b/src/primaite/game/agent/actions/manager.py
@@ -39,6 +39,13 @@ class DoNothingAction(AbstractAction, identifier="do_nothing"):
         return ["do_nothing"]
 
 
+class _ActionMapItem(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    action: str
+    options: Dict
+
+
 class ActionManager(BaseModel):
     """Class which manages the action space for an agent."""
 
@@ -46,20 +53,23 @@ class ActionManager(BaseModel):
         """Config Schema for ActionManager."""
 
         model_config = ConfigDict(extra="forbid")
-        action_map: Dict[int, Tuple[str, Dict]] = {}
+        action_map: Dict[int, _ActionMapItem] = {}
         """Mapping between integer action choices and CAOS actions."""
 
         @field_validator("action_map", mode="after")
         def consecutive_action_nums(cls, v: Dict) -> Dict:
             """Make sure all numbers between 0 and N are represented as dict keys in action map."""
             assert all([i in v.keys() for i in range(len(v))])
+            return v
 
     config: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
 
-    @property
-    def action_map(self) -> Dict[int, Tuple[str, Dict]]:
-        """Convenience method for accessing the action map."""
-        return self.config.action_map
+    action_map: Dict[int, Tuple[str, Dict]] = {}
+    """Init as empty, populate after model validation."""
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.action_map = {n: (v.action, v.options) for n, v in self.config.action_map.items()}
 
     def get_action(self, action: int) -> Tuple[str, Dict]:
         """
diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py
index 0b55c1db..3311de66 100644
--- a/src/primaite/game/agent/interface.py
+++ b/src/primaite/game/agent/interface.py
@@ -52,7 +52,7 @@ class AbstractAgent(BaseModel, ABC):
         """Configuration Schema for AbstractAgents."""
 
         model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
-        type: str = "AbstractAgent"
+        type: str
         action_space: ActionManager.ConfigSchema = Field(default_factory=lambda: ActionManager.ConfigSchema())
         observation_space: ObservationManager.ConfigSchema = Field(
             default_factory=lambda: ObservationManager.ConfigSchema()
@@ -63,9 +63,10 @@ class AbstractAgent(BaseModel, ABC):
 
     logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
     history: List[AgentHistoryItem] = []
-    action_manager: "ActionManager"
-    observation_manager: "ObservationManager"
-    reward_function: "RewardFunction"
+
+    action_manager: ActionManager = Field(default_factory=lambda: ActionManager())
+    observation_manager: ObservationManager = Field(default_factory=lambda: ObservationManager())
+    reward_function: RewardFunction = Field(default_factory=lambda: RewardFunction())
 
     _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {}
 
@@ -77,32 +78,18 @@ class AbstractAgent(BaseModel, ABC):
             raise ValueError(f"Cannot create a new agent under reserved name {identifier}")
         cls._registry[identifier] = cls
 
-    def __init__(self, config: ConfigSchema, **kwargs):
-        kwargs["action_manager"] = kwargs.get("action_manager") or ActionManager.from_config(config.action_space)
-        kwargs["observation_manager"] = kwargs.get("observation_manager") or ObservationManager(
-            config.observation_space
-        )
-        kwargs["reward_function"] = kwargs.get("reward_function") or RewardFunction.from_config(config.reward_function)
-        super().__init__(config=config, **kwargs)
+    def model_post_init(self, __context: Any) -> None:
+        """Overwrite the default empty action, observation, and rewards with ones defined through the config."""
+        self.action_manager = ActionManager(config=self.config.action_space)
+        self.observation_manager = ObservationManager(config=self.config.observation_space)
+        self.reward_function = RewardFunction(config=self.config.reward_function)
+        return super().model_post_init(__context)
 
     @property
     def flatten_obs(self) -> bool:
         """Return agent flatten_obs param."""
         return self.config.flatten_obs
 
-    @classmethod
-    def from_config(cls, config: Dict) -> "AbstractAgent":
-        """Creates an agent component from a configuration dictionary."""
-        if config["type"] not in cls._registry:
-            return ValueError(f"Invalid Agent Type: {config['type']}")
-        obj = cls(
-            config=cls.ConfigSchema(**config["agent_settings"]),
-            action_manager=ActionManager.from_config(config["action_space"]),
-            observation_manager=ObservationManager(config["observation_space"]),
-            reward_function=RewardFunction.from_config(config["reward_function"]),
-        )
-        return obj
-
     def update_observation(self, state: Dict) -> ObsType:
         """
         Convert a state from the simulator into an observation for the agent using the observation space.
diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py
index 6964ce2c..83d4a076 100644
--- a/src/primaite/game/agent/observations/observation_manager.py
+++ b/src/primaite/game/agent/observations/observation_manager.py
@@ -194,7 +194,7 @@ class ObservationManager(BaseModel):
             if "options" not in data:
                 data["options"] = obs_class.ConfigSchema()
 
-            # if options passed as a dict, convert to a schema
+            # if options passed as a dict, validate against schema
             elif isinstance(data["options"], dict):
                 data["options"] = obs_class.ConfigSchema(**data["options"])
 
diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py
index 50fdaba8..d4c8ef9b 100644
--- a/src/primaite/game/agent/rewards.py
+++ b/src/primaite/game/agent/rewards.py
@@ -30,7 +30,7 @@ the structure:
 from abc import ABC, abstractmethod
 from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Never
 
 from primaite import getLogger
@@ -410,15 +410,74 @@ class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
         return self.config.action_penalty
 
 
-class RewardFunction:
+class _SingleComponentConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    type: str
+    options: AbstractReward.ConfigSchema
+    weight: float = 1.0
+
+    @model_validator(mode="before")
+    @classmethod
+    def resolve_obs_options_type(cls, data: Any) -> Any:
+        """
+        When constructing the model from a dict, resolve the correct reward class based on `type` field.
+
+        Workaround: The `options` field is statically typed as AbstractReward. Therefore, it falls over when
+        passing in data that adheres to a subclass schema rather than the plain AbstractReward schema. There is
+        a way to do this properly using a discriminated union, but most advice on the internet assumes that the full
+        list of types between which to discriminate is known ahead-of-time. That is not the case for us, because of
+        our plugin architecture.
+
+        We may be able to revisit and implement a better solution when needed using the following resources as
+        research starting points:
+        https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
+        https://github.com/pydantic/pydantic/issues/7366
+        https://github.com/pydantic/pydantic/issues/7462
+        https://github.com/pydantic/pydantic/pull/7983
+        """
+        if not isinstance(data, dict):
+            return data
+
+        if "type" not in data:  # explicit raise: an assert here would be stripped under `python -O`
+            raise ValueError('Reward component definition is missing the "type" key.')
+        rew_class = AbstractReward._registry[data["type"]]
+
+        # if no options are passed in, try to create a default schema. Only works if there are no mandatory fields.
+        if "options" not in data:
+            data["options"] = rew_class.ConfigSchema()
+
+        # if options are passed as a dict, validate against schema
+        elif isinstance(data["options"], dict):
+            data["options"] = rew_class.ConfigSchema(**data["options"])
+
+        return data
+
+
+class RewardFunction(BaseModel):
     """Manages the reward function for the agent."""
 
-    def __init__(self):
-        """Initialise the reward function object."""
-        self.reward_components: List[Tuple[AbstractReward, float]] = []
-        "attribute reward_components keeps track of reward components and the weights assigned to each."
-        self.current_reward: float = 0.0
-        self.total_reward: float = 0.0
+    model_config = ConfigDict(extra="forbid")
+
+    class ConfigSchema(BaseModel):
+        """Config Schema for RewardFunction."""
+
+        model_config = ConfigDict(extra="forbid")
+
+        reward_components: List[_SingleComponentConfig] = []  # List: Iterable validates to a one-shot iterator
+
+    config: ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
+
+    reward_components: List[Tuple[AbstractReward, float]] = []
+
+    current_reward: float = 0.0
+    total_reward: float = 0.0
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        for rew_config in self.config.reward_components:
+            rew_class = AbstractReward._registry[rew_config.type]
+            rew_instance = rew_class(config=rew_config.options)
+            self.register_component(component=rew_instance, weight=rew_config.weight)
 
     def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
         """Add a reward component to the reward function.
diff --git a/tests/unit_tests/_primaite/_game/_agent/test_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_agent.py
new file mode 100644
index 00000000..5f3b4fc0
--- /dev/null
+++ b/tests/unit_tests/_primaite/_game/_agent/test_agent.py
@@ -0,0 +1,50 @@
+from primaite.game.agent.observations.file_system_observations import FileObservation
+from primaite.game.agent.observations.observation_manager import NullObservation
+from primaite.game.agent.scripted_agents.random_agent import RandomAgent
+
+
+def test_creating_empty_agent():
+    agent = RandomAgent()
+    assert len(agent.action_manager.action_map) == 0
+    assert isinstance(agent.observation_manager.obs, NullObservation)
+    assert len(agent.reward_function.reward_components) == 0
+
+
+def test_creating_agent_from_dict():
+    action_config = {
+        "action_map": {
+            0: {"action": "do_nothing", "options": {}},
+            1: {
+                "action": "node_application_execute",
+                "options": {"node_name": "client", "application_name": "database"},
+            },
+        }
+    }
+    observation_config = {
+        "type": "FILE",
+        "options": {
+            "file_name": "dog.pdf",
+            "include_num_access": False,
+            "file_system_requires_scan": False,
+        },
+    }
+    reward_config = {
+        "reward_components": [
+            {
+                "type": "DATABASE_FILE_INTEGRITY",
+                "weight": 0.3,
+                "options": {"node_hostname": "server", "folder_name": "database", "file_name": "database.db"},
+            }
+        ]
+    }
+    agent = RandomAgent(
+        config={
+            "action_space": action_config,
+            "observation_space": observation_config,
+            "reward_function": reward_config,
+        }
+    )
+
+    assert len(agent.action_manager.action_map) == 2
+    assert isinstance(agent.observation_manager.obs, FileObservation)
+    assert len(agent.reward_function.reward_components) == 1