diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml
index 6cded5f2..1ec98f39 100644
--- a/src/primaite/config/_package_data/data_manipulation.yaml
+++ b/src/primaite/config/_package_data/data_manipulation.yaml
@@ -740,7 +740,6 @@ agents:
             agent_name: client_2_green_user
 
 
-
 
     agent_settings:
       flatten_obs: true
diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py
index cabea5f4..7d14e097 100644
--- a/src/primaite/game/agent/rewards.py
+++ b/src/primaite/game/agent/rewards.py
@@ -360,6 +360,59 @@ class SharedReward(AbstractReward):
         return cls(agent_name=agent_name)
 
 
+class ActionPenalty(AbstractReward):
+    """
+    Apply a negative reward whenever the agent takes any action other than DONOTHING.
+
+    The size of the penalty is optional in the configuration and defaults to 0, in which
+    case this component never changes the reward.
+    """
+
+    def __init__(self, agent_name: str, penalty: float = 0) -> None:
+        """
+        Initialise the reward.
+
+        :param agent_name: Name of the agent this penalty is configured for.
+        :type agent_name: str
+        :param penalty: Size of the penalty applied per non-DONOTHING action. Only the magnitude is
+            used, so ``1`` and ``-1`` both produce a reward of ``-1``. Defaults to 0.
+        :type penalty: float
+        """
+        self.agent_name = agent_name
+        self.penalty = penalty
+
+    def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
+        """
+        Calculate the penalty to be applied.
+
+        :param state: State dictionary (not used by this component).
+        :type state: Dict
+        :param last_action_response: Item whose ``action`` attribute names the agent's last action.
+        :type last_action_response: AgentHistoryItem
+        :return: 0 if the last action was DONOTHING, otherwise the configured penalty as a negative number.
+        :rtype: float
+        """
+        if last_action_response.action == "DONOTHING":
+            # Doing nothing is never penalised.
+            return 0
+        # Negate the magnitude so a misconfigured sign can never turn the penalty into a bonus.
+        return -abs(self.penalty)
+
+    @classmethod
+    def from_config(cls, config: Dict) -> "ActionPenalty":
+        """
+        Build the ActionPenalty object from config.
+
+        :param config: Options for this component. Reads ``agent_name`` and the optional ``penalty``.
+        :type config: Dict
+        :return: The configured reward component.
+        :rtype: ActionPenalty
+        """
+        penalty = config.get("penalty")
+        # A missing or null ``penalty`` falls back to the documented default of 0.
+        return cls(agent_name=config.get("agent_name"), penalty=0 if penalty is None else penalty)
+
+
 class RewardFunction:
     """Manages the reward function for the agent."""
 
@@ -370,6 +423,7 @@ class RewardFunction:
         "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
         "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
         "SHARED_REWARD": SharedReward,
+        "ACTION_PENALTY": ActionPenalty,
     }
     """List of reward class identifiers."""
 