#2913: Make rewards work with config file.

This commit is contained in:
Nick Todd
2024-11-06 11:35:06 +00:00
parent b849ea6312
commit 370bcfc476
3 changed files with 23 additions and 32 deletions

View File

@@ -79,9 +79,8 @@ class AbstractReward(BaseModel):
if config["type"] not in cls._registry:
raise ValueError(f"Invalid reward type {config['type']}")
reward_class = cls._registry[config["type"]]
reward_config = reward_class.ConfigSchema(**config)
reward_class(config=reward_config)
return reward_class
reward_obj = reward_class(config=reward_class.ConfigSchema(**config))
return reward_obj
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
@@ -97,7 +96,7 @@ class AbstractReward(BaseModel):
return 0.0
class DummyReward(AbstractReward, identifier="DummyReward"):
class DummyReward(AbstractReward, identifier="DUMMY"):
"""Dummy reward function component which always returns 0.0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
@@ -113,7 +112,7 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
return 0.0
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
config: "DatabaseFileIntegrity.ConfigSchema"
@@ -123,7 +122,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
type: str = "DatabaseFileIntegrity"
type: str = "DATABASE_FILE_INTEGRITY"
node_hostname: str
folder_name: str
file_name: str
@@ -166,7 +165,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
return 0
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
config: "WebServer404Penalty.ConfigSchema"
@@ -176,7 +175,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebServer404Penalty."""
type: str = "WebServer404Penalty"
type: str = "WEB_SERVER_404_PENALTY"
node_hostname: str
service_name: str
sticky: bool = True
@@ -220,7 +219,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
return self.reward
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
config: "WebpageUnavailablePenalty.ConfigSchema"
@@ -230,7 +229,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
type: str = "WebpageUnavailablePenalty"
type: str = "WEBPAGE_UNAVAILABLE_PENALTY"
node_hostname: str = ""
sticky: bool = True
@@ -294,7 +293,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
return self.reward
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"):
"""Penalises the agent when the green db clients fail to connect to the database."""
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
@@ -303,7 +302,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
type: str = "GreenAdminDatabaseUnreachablePenalty"
type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"
node_hostname: str
sticky: bool = True
@@ -344,7 +343,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
return self.reward
class SharedReward(AbstractReward, identifier="SharedReward"):
class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
"""Adds another agent's reward to the overall reward."""
config: "SharedReward.ConfigSchema"
@@ -352,7 +351,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for SharedReward."""
type: str = "SharedReward"
type: str = "SHARED_REWARD"
agent_name: str
def default_callback(agent_name: str) -> Never:
@@ -381,7 +380,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
return self.callback(self.config.agent_name)
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
"""Apply a negative reward when taking any action except DONOTHING."""
config: "ActionPenalty.ConfigSchema"
@@ -389,7 +388,7 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty."""
type: str = "ActionPenalty"
type: str = "ACTION_PENALTY"
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
@@ -412,17 +411,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
class RewardFunction:
"""Manages the reward function for the agent."""
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
"ACTION_PENALTY": ActionPenalty,
}
"""List of reward class identifiers."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
@@ -457,7 +445,7 @@ class RewardFunction:
@classmethod
def from_config(cls, config: Dict) -> "RewardFunction":
"""Create a reward function from a config dictionary.
"""Create a reward function from a config dictionary and its related reward class.
:param config: dict of options for the reward manager's constructor
:type config: Dict
@@ -468,8 +456,11 @@ class RewardFunction:
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
# XXX: If options key is missing add key then add type key.
if "options" not in rew_component_cfg:
rew_component_cfg["options"] = {}
rew_component_cfg["options"]["type"] = rew_type
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
rew_instance = AbstractReward.from_config(rew_component_cfg["options"])
new.register_component(component=rew_instance, weight=weight)
return new

View File

@@ -629,7 +629,7 @@ class PrimaiteGame:
for comp, weight in agent.reward_function.reward_components:
if isinstance(comp, SharedReward):
comp: SharedReward
graph[name].add(comp.agent_name)
graph[name].add(comp.config.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward

View File

@@ -25,7 +25,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
agent: ControlledAgent
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
comp = WebpageUnavailablePenalty(config=schema)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()