#2913: Now with ConfigSchemas.

This commit is contained in:
Nick Todd
2024-10-17 16:35:13 +01:00
parent fe6a8e6e97
commit 419a86114d

View File

@@ -87,7 +87,7 @@ class AbstractReward(BaseModel):
class DummyReward(AbstractReward, identifier="DummyReward"):
"""Dummy reward function component which always returns 0.0"""
"""Dummy reward function component which always returns 0.0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -116,6 +116,13 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
node_hostname: str
folder_name: str
file_name: str
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
@@ -186,6 +193,13 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebServer404Penalty."""
node_hostname: str
service_name: str
sticky: bool = True
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
@@ -259,6 +273,12 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
node_hostname: str
sticky: bool = True
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
@@ -343,6 +363,12 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
"""Penalises the agent when the green db clients fail to connect to the database."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
node_hostname: str
sticky: bool = True
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
@@ -411,6 +437,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
class SharedReward(AbstractReward, identifier="SharedReward"):
"""Adds another agent's reward to the overall reward."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for SharedReward."""
agent_name: str
def __init__(self, agent_name: Optional[str] = None) -> None:
"""
Initialise the shared reward.
@@ -464,6 +495,12 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
"""Apply a negative reward when taking any action except DONOTHING."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty."""
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.