#2913: Fix ActionPenalty.
This commit is contained in:
@@ -28,7 +28,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Never
|
||||
@@ -51,6 +51,8 @@ class AbstractReward(BaseModel):
|
||||
|
||||
type: str
|
||||
|
||||
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
@@ -58,7 +60,6 @@ class AbstractReward(BaseModel):
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict) -> "AbstractReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
@@ -69,7 +70,8 @@ class AbstractReward(BaseModel):
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid reward type {config['type']}")
|
||||
|
||||
adder_class = cls._registry[config["type"]]
|
||||
adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config))
|
||||
return cls
|
||||
|
||||
@abstractmethod
|
||||
@@ -276,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebpageUnavailablePenalty."""
|
||||
|
||||
node_hostname: str
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
@@ -494,6 +496,8 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty."""
|
||||
@@ -501,19 +505,20 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
"""
|
||||
Initialise the reward.
|
||||
# def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
# """
|
||||
# Initialise the reward.
|
||||
|
||||
Reward or penalise agents for doing nothing or taking actions.
|
||||
# Reward or penalise agents for doing nothing or taking actions.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
self.action_penalty = action_penalty
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
# :param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
# :type action_penalty: float
|
||||
# :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
# :type do_nothing_penalty: float
|
||||
# """
|
||||
# super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
# self.action_penalty = action_penalty
|
||||
# self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied.
|
||||
|
||||
Reference in New Issue
Block a user