diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 03764e4b..029597a0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -79,9 +79,8 @@ class AbstractReward(BaseModel): if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") reward_class = cls._registry[config["type"]] - reward_config = reward_class.ConfigSchema(**config) - reward_class(config=reward_config) - return reward_class + reward_obj = reward_class(config=reward_class.ConfigSchema(**config)) + return reward_obj @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -97,7 +96,7 @@ class AbstractReward(BaseModel): return 0.0 -class DummyReward(AbstractReward, identifier="DummyReward"): +class DummyReward(AbstractReward, identifier="DUMMY"): """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -113,7 +112,7 @@ class DummyReward(AbstractReward, identifier="DummyReward"): return 0.0 -class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" config: "DatabaseFileIntegrity.ConfigSchema" @@ -123,7 +122,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" - type: str = "DatabaseFileIntegrity" + type: str = "DATABASE_FILE_INTEGRITY" node_hostname: str folder_name: str file_name: str @@ -166,7 +165,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): return 0 -class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): +class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): """Reward function component which penalises the agent when the web server returns a 404 error.""" config: "WebServer404Penalty.ConfigSchema" @@ -176,7 +175,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" - type: str = "WebServer404Penalty" + type: str = "WEB_SERVER_404_PENALTY" node_hostname: str service_name: str sticky: bool = True @@ -220,7 +219,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): return self.reward -class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): +class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"): """Penalises the agent when the web browser fails to fetch a webpage.""" config: "WebpageUnavailablePenalty.ConfigSchema" @@ -230,7 +229,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" - type: str = "WebpageUnavailablePenalty" + type: str = "WEBPAGE_UNAVAILABLE_PENALTY" node_hostname: str = "" sticky: bool = True @@ -294,7 +293,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe return self.reward -class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): """Penalises the agent when the green db clients fail to connect to the database.""" config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" @@ -303,7 +302,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" - type: str = "GreenAdminDatabaseUnreachablePenalty" + type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY" node_hostname: str sticky: bool = True @@ -344,7 +343,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi return self.reward -class SharedReward(AbstractReward, identifier="SharedReward"): +class SharedReward(AbstractReward, identifier="SHARED_REWARD"): """Adds another agent's reward to the overall reward.""" config: "SharedReward.ConfigSchema" @@ -352,7 +351,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" - type: str = "SharedReward" + type: str = "SHARED_REWARD" agent_name: str def default_callback(agent_name: str) -> Never: @@ -381,7 +380,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward, identifier="ActionPenalty"): +class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): """Apply a negative reward when taking any action except DONOTHING.""" config: "ActionPenalty.ConfigSchema" @@ -389,7 +388,7 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" - type: str = "ActionPenalty" + type: str = "ACTION_PENALTY" action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -412,17 +411,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): class RewardFunction: """Manages the reward function for the agent.""" - rew_class_identifiers: Dict[str, Type[AbstractReward]] = { - "DUMMY": DummyReward, - "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, - "WEB_SERVER_404_PENALTY": WebServer404Penalty, - "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, - "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, - "SHARED_REWARD": SharedReward, - "ACTION_PENALTY": ActionPenalty, - } - """List of reward class identifiers.""" - def __init__(self): """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] @@ -457,7 +445,7 @@ class RewardFunction: @classmethod def from_config(cls, config: Dict) -> "RewardFunction": - """Create a reward function from a config dictionary. + """Create a reward function from a config dictionary and its related reward class. :param config: dict of options for the reward manager's constructor :type config: Dict @@ -468,8 +456,11 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] + # XXX: If options key is missing add key then add type key. + if "options" not in rew_component_cfg: + rew_component_cfg["options"] = {} + rew_component_cfg["options"]["type"] = rew_type weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) + rew_instance = AbstractReward.from_config(rew_component_cfg["options"]) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 7c5c93bc..51d0306c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -629,7 +629,7 @@ class PrimaiteGame: for comp, weight in agent.reward_function.reward_components: if isinstance(comp, SharedReward): comp: SharedReward - graph[name].add(comp.agent_name) + graph[name].add(comp.config.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index d4236d1b..6544c82d 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -25,7 +25,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent: ControlledAgent schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) comp = WebpageUnavailablePenalty(config=schema) - + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run()