From bbcbb26f5edd79202bfed16b46a7bc3573b60397 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 21 Oct 2024 14:43:51 +0100 Subject: [PATCH] #2913: Fix ActionPenalty. --- src/primaite/game/agent/rewards.py | 35 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 4198af27..1158d919 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -28,7 +28,7 @@ the structure: ``` """ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union from pydantic import BaseModel from typing_extensions import Never @@ -51,6 +51,8 @@ class AbstractReward(BaseModel): type: str + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if identifier in cls._registry: @@ -58,7 +60,6 @@ class AbstractReward(BaseModel): cls._registry[identifier] = cls @classmethod - @abstractmethod def from_config(cls, config: Dict) -> "AbstractReward": """Create a reward function component from a config dictionary. @@ -69,7 +70,8 @@ class AbstractReward(BaseModel): """ if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") - + adder_class = cls._registry[config["type"]] + adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) return cls @abstractmethod @@ -276,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" - node_hostname: str + node_hostname: str = "" sticky: bool = True def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -494,6 +496,8 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" @@ -501,19 +505,20 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 - def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - """ - Initialise the reward. + # def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: + # """ + # Initialise the reward. - Reward or penalise agents for doing nothing or taking actions. + # Reward or penalise agents for doing nothing or taking actions. - :param action_penalty: Reward to give agents for taking any action except DONOTHING - :type action_penalty: float - :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action - :type do_nothing_penalty: float - """ - self.action_penalty = action_penalty - self.do_nothing_penalty = do_nothing_penalty + # :param action_penalty: Reward to give agents for taking any action except DONOTHING + # :type action_penalty: float + # :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action + # :type do_nothing_penalty: float + # """ + # super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + # self.action_penalty = action_penalty + # self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied.