From fe6a8e6e97cb7a3837091950d7038b52fe1ffbff Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 17 Oct 2024 13:24:57 +0100 Subject: [PATCH] #2913: Initial commit of new AbstractReward class. --- src/primaite/game/agent/rewards.py | 61 +++++++++++++++++++----------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1de34b40..0db3cc28 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -27,9 +27,10 @@ the structure: service_ref: web_server_database_client ``` """ -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from pydantic import BaseModel from typing_extensions import Never from primaite import getLogger @@ -42,9 +43,35 @@ _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] -class AbstractReward: +class AbstractReward(BaseModel): """Base class for reward function components.""" + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" + + type: str + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate node adder {identifier}") + cls._registry[identifier] = cls + + @classmethod + @abstractmethod + def from_config(cls, config: Dict) -> "AbstractReward": + """Create a reward function component from a config dictionary. + + :param config: dict of options for the reward component's constructor + :type config: dict + :return: The reward component. + :rtype: AbstractReward + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + + return cls + @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -58,21 +85,9 @@ class AbstractReward: """ return 0.0 - @classmethod - @abstractmethod - def from_config(cls, config: dict) -> "AbstractReward": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: dict - :return: The reward component. - :rtype: AbstractReward - """ - return cls() - - -class DummyReward(AbstractReward): - """Dummy reward function component which always returns 0.""" +class DummyReward(AbstractReward, identifier="DummyReward"): + """Dummy reward function component which always returns 0.0""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -98,7 +113,7 @@ class DummyReward(AbstractReward): return cls() -class DatabaseFileIntegrity(AbstractReward): +class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: @@ -168,7 +183,7 @@ class DatabaseFileIntegrity(AbstractReward): return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) -class WebServer404Penalty(AbstractReward): +class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: @@ -241,7 +256,7 @@ class WebServer404Penalty(AbstractReward): return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) -class WebpageUnavailablePenalty(AbstractReward): +class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -325,7 +340,7 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname, sticky=sticky) -class GreenAdminDatabaseUnreachablePenalty(AbstractReward): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -393,7 +408,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return cls(node_hostname=node_hostname, sticky=sticky) -class SharedReward(AbstractReward): +class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" def __init__(self, agent_name: Optional[str] = None) -> None: @@ -446,7 +461,7 @@ class SharedReward(AbstractReward): return cls(agent_name=agent_name) -class ActionPenalty(AbstractReward): +class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: