#2913: Fix ActionPenalty.

This commit is contained in:
Nick Todd
2024-10-21 14:43:51 +01:00
parent 419a86114d
commit bbcbb26f5e

View File

@@ -28,7 +28,7 @@ the structure:
```
"""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from typing_extensions import Never
@@ -51,6 +51,8 @@ class AbstractReward(BaseModel):
type: str
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
@@ -58,7 +60,6 @@ class AbstractReward(BaseModel):
cls._registry[identifier] = cls
@classmethod
@abstractmethod
def from_config(cls, config: Dict) -> "AbstractReward":
"""Create a reward function component from a config dictionary.
@@ -69,7 +70,8 @@ class AbstractReward(BaseModel):
"""
if config["type"] not in cls._registry:
raise ValueError(f"Invalid reward type {config['type']}")
adder_class = cls._registry[config["type"]]
adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config))
return cls
@abstractmethod
@@ -276,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
node_hostname: str
node_hostname: str = ""
sticky: bool = True
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
@@ -494,6 +496,8 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
"""Apply a negative reward when taking any action except DONOTHING."""
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty."""
@@ -501,19 +505,20 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
# def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
# """
# Initialise the reward.
Reward or penalise agents for doing nothing or taking actions.
# Reward or penalise agents for doing nothing or taking actions.
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
:type do_nothing_penalty: float
"""
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
# :param action_penalty: Reward to give agents for taking any action except DONOTHING
# :type action_penalty: float
# :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
# :type do_nothing_penalty: float
# """
# super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
# self.action_penalty = action_penalty
# self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied.