#2913: Make rewards work with config file.
This commit is contained in:
@@ -79,9 +79,8 @@ class AbstractReward(BaseModel):
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid reward type {config['type']}")
|
||||
reward_class = cls._registry[config["type"]]
|
||||
reward_config = reward_class.ConfigSchema(**config)
|
||||
reward_class(config=reward_config)
|
||||
return reward_class
|
||||
reward_obj = reward_class(config=reward_class.ConfigSchema(**config))
|
||||
return reward_obj
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
@@ -97,7 +96,7 @@ class AbstractReward(BaseModel):
|
||||
return 0.0
|
||||
|
||||
|
||||
class DummyReward(AbstractReward, identifier="DummyReward"):
|
||||
class DummyReward(AbstractReward, identifier="DUMMY"):
|
||||
"""Dummy reward function component which always returns 0.0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
@@ -113,7 +112,7 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
|
||||
return 0.0
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
config: "DatabaseFileIntegrity.ConfigSchema"
|
||||
@@ -123,7 +122,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseFileIntegrity."""
|
||||
|
||||
type: str = "DatabaseFileIntegrity"
|
||||
type: str = "DATABASE_FILE_INTEGRITY"
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
@@ -166,7 +165,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
return 0
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
config: "WebServer404Penalty.ConfigSchema"
|
||||
@@ -176,7 +175,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebServer404Penalty."""
|
||||
|
||||
type: str = "WebServer404Penalty"
|
||||
type: str = "WEB_SERVER_404_PENALTY"
|
||||
node_hostname: str
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
@@ -220,7 +219,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
return self.reward
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
config: "WebpageUnavailablePenalty.ConfigSchema"
|
||||
@@ -230,7 +229,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebpageUnavailablePenalty."""
|
||||
|
||||
type: str = "WebpageUnavailablePenalty"
|
||||
type: str = "WEBPAGE_UNAVAILABLE_PENALTY"
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
|
||||
@@ -294,7 +293,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
return self.reward
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
|
||||
@@ -303,7 +302,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
|
||||
|
||||
type: str = "GreenAdminDatabaseUnreachablePenalty"
|
||||
type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"
|
||||
node_hostname: str
|
||||
sticky: bool = True
|
||||
|
||||
@@ -344,7 +343,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
return self.reward
|
||||
|
||||
|
||||
class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
config: "SharedReward.ConfigSchema"
|
||||
@@ -352,7 +351,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for SharedReward."""
|
||||
|
||||
type: str = "SharedReward"
|
||||
type: str = "SHARED_REWARD"
|
||||
agent_name: str
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
@@ -381,7 +380,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
return self.callback(self.config.agent_name)
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
@@ -389,7 +388,7 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty."""
|
||||
|
||||
type: str = "ActionPenalty"
|
||||
type: str = "ACTION_PENALTY"
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
@@ -412,17 +411,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
"ACTION_PENALTY": ActionPenalty,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
@@ -457,7 +445,7 @@ class RewardFunction:
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "RewardFunction":
|
||||
"""Create a reward function from a config dictionary.
|
||||
"""Create a reward function from a config dictionary and its related reward class.
|
||||
|
||||
:param config: dict of options for the reward manager's constructor
|
||||
:type config: Dict
|
||||
@@ -468,8 +456,11 @@ class RewardFunction:
|
||||
|
||||
for rew_component_cfg in config["reward_components"]:
|
||||
rew_type = rew_component_cfg["type"]
|
||||
# XXX: If the "options" key is missing, add it as an empty dict; then store the reward type under options["type"].
|
||||
if "options" not in rew_component_cfg:
|
||||
rew_component_cfg["options"] = {}
|
||||
rew_component_cfg["options"]["type"] = rew_type
|
||||
weight = rew_component_cfg.get("weight", 1.0)
|
||||
rew_class = cls.rew_class_identifiers[rew_type]
|
||||
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
|
||||
rew_instance = AbstractReward.from_config(rew_component_cfg["options"])
|
||||
new.register_component(component=rew_instance, weight=weight)
|
||||
return new
|
||||
|
||||
@@ -629,7 +629,7 @@ class PrimaiteGame:
|
||||
for comp, weight in agent.reward_function.reward_components:
|
||||
if isinstance(comp, SharedReward):
|
||||
comp: SharedReward
|
||||
graph[name].add(comp.agent_name)
|
||||
graph[name].add(comp.config.agent_name)
|
||||
|
||||
# while constructing the graph, we might as well set up the reward sharing itself.
|
||||
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
agent: ControlledAgent
|
||||
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
|
||||
comp = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
|
||||
Reference in New Issue
Block a user