#2913: Initial commit of new AbstractReward class.
This commit is contained in:
@@ -27,9 +27,10 @@ the structure:
|
||||
service_ref: web_server_database_client
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -42,9 +43,35 @@ _LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
|
||||
|
||||
class AbstractReward:
|
||||
class AbstractReward(BaseModel):
|
||||
"""Base class for reward function components."""
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Config schema for AbstractReward."""
|
||||
|
||||
# NOTE(review): presumably mirrors the "type" key that from_config checks against the registry — confirm.
type: str
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
    """Register a reward component subclass under its identifier.

    Runs once for every subclass definition, e.g.
    ``class DummyReward(AbstractReward, identifier="DummyReward")``.

    :param identifier: Unique name under which the subclass is stored in the registry.
    :type identifier: str
    :param kwargs: Remaining class keyword arguments, forwarded to the parent hook.
    :raises ValueError: If another subclass is already registered under ``identifier``.
    """
    super().__init_subclass__(**kwargs)
    # NOTE(review): ``_registry`` is not declared in the visible part of the class;
    # it is assumed to be a class-level mapping shared by all subclasses — confirm.
    if identifier in cls._registry:
        # The message names the reward registry (it previously said "node adder").
        raise ValueError(f"Duplicate reward component {identifier}")
    cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: dict
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
:raises ValueError: If ``config["type"]`` is not a key of the class registry.
"""
|
||||
# Reject unknown reward types before doing anything else.
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid reward type {config['type']}")
|
||||
|
||||
# NOTE(review): this returns the class object itself, not an instance, although the
# annotation and docstring describe a reward component — confirm the intended construction.
return cls
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -58,21 +85,9 @@ class AbstractReward:
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: dict) -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: dict
|
||||
:return: The reward component.
|
||||
:rtype: AbstractReward
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
class DummyReward(AbstractReward, identifier="DummyReward"):
|
||||
"""Dummy reward function component which always returns 0.0"""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
@@ -98,7 +113,7 @@ class DummyReward(AbstractReward):
|
||||
return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
|
||||
@@ -168,7 +183,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
|
||||
@@ -241,7 +256,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward):
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
@@ -325,7 +340,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
@@ -393,7 +408,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
@@ -446,7 +461,7 @@ class SharedReward(AbstractReward):
|
||||
return cls(agent_name=agent_name)
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward):
|
||||
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
|
||||
Reference in New Issue
Block a user