diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1158d919..9777441b 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -274,28 +274,34 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" + node_hostname: str = "" + sticky: bool = True + reward: float = 0.0 + location_in_state: List[str] = [""] + _node: str = node_hostname class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" node_hostname: str = "" sticky: bool = True + reward: float = 0.0 - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + # def __init__(self, node_hostname: str, sticky: bool = True) -> None: + # """ + # Initialise the reward component. - :param node_hostname: Hostname of the node which has the web browser. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + # :param node_hostname: Hostname of the node which has the web browser. + # :type node_hostname: str + # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + # the reward if there were any responses this timestep. + # :type sticky: bool + # """ + # self._node: str = node_hostname + # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + # self.sticky: bool = sticky + # self.reward: float = 0.0 + # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -311,6 +317,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe :return: Reward value :rtype: float """ + self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -364,6 +371,10 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" + node_hostname: str = "" + _node: str = node_hostname + sticky: bool = True + reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" @@ -371,21 +382,21 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi node_hostname: str sticky: bool = True - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + # def __init__(self, node_hostname: str, sticky: bool = True) -> None: + # """ + # Initialise the reward component. - :param node_hostname: Hostname of the node where the database client sits. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + # :param node_hostname: Hostname of the node where the database client sits. + # :type node_hostname: str + # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + # the reward if there were any responses this timestep. + # :type sticky: bool + # """ + # self._node: str = node_hostname + # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + # self.sticky: bool = sticky + # self.reward: float = 0.0 + # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -438,37 +449,38 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" + agent_name: str class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" agent_name: str - def __init__(self, agent_name: Optional[str] = None) -> None: + # def __init__(self, agent_name: Optional[str] = None) -> None: + # """ + # Initialise the shared reward. + + # The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work + # correctly. + + # :param agent_name: The name whose reward is an input + # :type agent_name: Optional[str] + # """ + # # self.agent_name = agent_name + # """Agent whose reward to track.""" + + def default_callback(agent_name: str) -> Never: """ - Initialise the shared reward. + Default callback to prevent calling this reward until it's properly initialised. - The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - correctly. - - :param agent_name: The name whose reward is an input - :type agent_name: Optional[str] + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. """ - self.agent_name = agent_name - """Agent whose reward to track.""" + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - def default_callback(agent_name: str) -> Never: - """ - Default callback to prevent calling this reward until it's properly initialised. - - SharedReward should not be used until the game layer replaces self.callback with a reference to the - function that retrieves the desired agent's reward. Therefore, we define this default callback that raises - an error. - """ - raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - - self.callback: Callable[[str], float] = default_callback - """Method that retrieves an agent's current reward given the agent's name.""" + callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it. diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 0005b508..bf707feb 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -74,7 +74,7 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty("client_1") + comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1") request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request)