#2913: Initial commit of new AbstractReward class.

This commit is contained in:
Nick Todd
2024-10-17 13:24:57 +01:00
parent 6844bd692a
commit fe6a8e6e97

View File

@@ -27,9 +27,10 @@ the structure:
service_ref: web_server_database_client
```
"""
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from typing_extensions import Never
from primaite import getLogger
@@ -42,9 +43,35 @@ _LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward:
class AbstractReward(BaseModel):
"""Base class for reward function components."""
class ConfigSchema(BaseModel, ABC):
    """Config schema for AbstractReward.

    Declares the single field ``type`` as a string.
    """

    # NOTE(review): presumably the reward type name that from_config looks up in the registry — confirm.
    type: str
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
    """Register a newly declared subclass under ``identifier``.

    Runs when a subclass is declared with ``identifier="..."`` in its class
    statement, and stores that subclass in ``cls._registry``.

    :param identifier: Key under which the subclass is stored in the registry.
    :type identifier: str
    :param kwargs: Remaining class keyword arguments, forwarded to the parent hook.
    :raises ValueError: If ``identifier`` is already present in the registry.
    """
    # Reject a clashing identifier before running the parent hook so the
    # offending class definition fails fast and nothing is registered.
    if identifier in cls._registry:
        raise ValueError(f"Duplicate reward component identifier {identifier}")
    super().__init_subclass__(**kwargs)
    cls._registry[identifier] = cls
@classmethod
@abstractmethod
def from_config(cls, config: Dict) -> "AbstractReward":
    """Validate the configured reward type against the registry.

    Looks up ``config["type"]`` in ``cls._registry`` and raises if it is not registered.

    NOTE(review): on success this returns ``cls`` itself, not an instance,
    although the annotation says ``AbstractReward`` — confirm whether
    construction is meant to happen here.

    :param config: dict of options for the reward component; must contain a ``"type"`` key
    :type config: dict
    :return: ``cls`` when the configured type is registered.
    :rtype: AbstractReward
    :raises ValueError: If ``config["type"]`` is not a key of the registry.
    """
    requested_type = config["type"]
    if requested_type in cls._registry:
        return cls
    raise ValueError(f"Invalid reward type {config['type']}")
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -58,21 +85,9 @@ class AbstractReward:
"""
return 0.0
@classmethod
@abstractmethod
def from_config(cls, config: dict) -> "AbstractReward":
    """Create a reward function component from a config dictionary.

    This base implementation does not read ``config``; it calls ``cls()`` with no arguments.

    :param config: dict of options for the reward component's constructor
    :type config: dict
    :return: The reward component.
    :rtype: AbstractReward
    """
    return cls()
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
class DummyReward(AbstractReward, identifier="DummyReward"):
"""Dummy reward function component which always returns 0.0"""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -98,7 +113,7 @@ class DummyReward(AbstractReward):
return cls()
class DatabaseFileIntegrity(AbstractReward):
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
@@ -168,7 +183,7 @@ class DatabaseFileIntegrity(AbstractReward):
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
@@ -241,7 +256,7 @@ class WebServer404Penalty(AbstractReward):
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
@@ -325,7 +340,7 @@ class WebpageUnavailablePenalty(AbstractReward):
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
@@ -393,7 +408,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):
class SharedReward(AbstractReward, identifier="SharedReward"):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
@@ -446,7 +461,7 @@ class SharedReward(AbstractReward):
return cls(agent_name=agent_name)
class ActionPenalty(AbstractReward):
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
"""Apply a negative reward when taking any action except DONOTHING."""
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: