From 419a86114d0b6eca4d2da343c427ddddf1b297cf Mon Sep 17 00:00:00 2001
From: Nick Todd
Date: Thu, 17 Oct 2024 16:35:13 +0100
Subject: [PATCH] #2913: Now with ConfigSchemas.

---
 src/primaite/game/agent/rewards.py | 39 +++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py
index 0db3cc28..4198af27 100644
--- a/src/primaite/game/agent/rewards.py
+++ b/src/primaite/game/agent/rewards.py
@@ -87,7 +87,7 @@ class AbstractReward(BaseModel):
 
 
 class DummyReward(AbstractReward, identifier="DummyReward"):
-    """Dummy reward function component which always returns 0.0"""
+    """Dummy reward function component which always returns 0.0."""
 
     def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
         """Calculate the reward for the current state.
@@ -116,6 +116,13 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
 class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
     """Reward function component which rewards the agent for maintaining the integrity of a database file."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """ConfigSchema for DatabaseFileIntegrity."""
+
+        node_hostname: str
+        folder_name: str
+        file_name: str
+
     def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
         """Initialise the reward component.
 
@@ -186,6 +193,13 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
 class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
     """Reward function component which penalises the agent when the web server returns a 404 error."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """ConfigSchema for WebServer404Penalty."""
+
+        node_hostname: str
+        service_name: str
+        sticky: bool = True
+
     def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
         """Initialise the reward component.
 
@@ -259,6 +273,12 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
 class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
     """Penalises the agent when the web browser fails to fetch a webpage."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """ConfigSchema for WebpageUnavailablePenalty."""
+
+        node_hostname: str
+        sticky: bool = True
+
     def __init__(self, node_hostname: str, sticky: bool = True) -> None:
         """
         Initialise the reward component.
@@ -343,6 +363,12 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
 class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
     """Penalises the agent when the green db clients fail to connect to the database."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
+
+        node_hostname: str
+        sticky: bool = True
+
     def __init__(self, node_hostname: str, sticky: bool = True) -> None:
         """
         Initialise the reward component.
@@ -411,6 +437,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
 class SharedReward(AbstractReward, identifier="SharedReward"):
     """Adds another agent's reward to the overall reward."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """Config schema for SharedReward."""
+
+        agent_name: str
+
     def __init__(self, agent_name: Optional[str] = None) -> None:
         """
         Initialise the shared reward.
@@ -464,6 +495,12 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
 class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
     """Apply a negative reward when taking any action except DONOTHING."""
 
+    class ConfigSchema(AbstractReward.ConfigSchema):
+        """Config schema for ActionPenalty."""
+
+        action_penalty: float = -1.0
+        do_nothing_penalty: float = 0.0
+
     def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
         """
         Initialise the reward.