diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py
index 8ac3956c..321df098 100644
--- a/src/primaite/game/agent/rewards.py
+++ b/src/primaite/game/agent/rewards.py
@@ -180,13 +180,17 @@ class WebServer404Penalty(AbstractReward):
             return 0.0
 
         codes = web_service_state.get("response_codes_this_timestep")
-        if codes or not self.sticky:  # skip calculating if sticky and no new codes. Insted, reuse last step's value.
+        if codes:
 
             def status2rew(status: int) -> int:
                 """Map status codes to reward values."""
                 return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
 
             self.reward = sum(map(status2rew, codes)) / len(codes)  # convert form HTTP codes to rewards and average
+        elif not self.sticky:  # there are no codes, but reward is not sticky, set reward to 0
+            self.reward = 0.0
+        else:  # skip calculating if sticky and no new codes; instead, reuse last step's value
+            pass
 
         return self.reward
 
@@ -319,13 +323,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
         request returned was able to connect to the database server, because there has been an unsuccessful request
         since.
         """
-        db_state = access_from_nested_dict(state, self.location_in_state)
-
-        # If the last request was actually sent, then check if the connection was established.
-        if db_state is NOT_PRESENT_IN_STATE:
-            _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
-            self.reward = 0.0
-
         request_attempted = last_action_response.request == [
             "network",
             "node",
@@ -335,11 +332,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
             "execute",
         ]
 
-        # skip calculating if sticky and no new codes, reusing last step value
-        if not request_attempted and self.sticky:
-            return self.reward
+        if request_attempted:  # if agent makes request, always recalculate fresh value
+            self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
+        elif not self.sticky:  # if no new request and not sticky, set reward to 0
+            self.reward = 0.0
+        else:  # if no new request and sticky, reuse reward value from last step
+            pass
 
-        self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
         return self.reward
 
     @classmethod
diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
new file mode 100644
index 00000000..58f0fcc1
--- /dev/null
+++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
@@ -0,0 +1,310 @@
+# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
+
+from primaite.game.agent.interface import AgentHistoryItem
+from primaite.game.agent.rewards import (
+    GreenAdminDatabaseUnreachablePenalty,
+    WebpageUnavailablePenalty,
+    WebServer404Penalty,
+)
+from primaite.interface.request import RequestResponse
+
+
+class TestWebServer404PenaltySticky:
+    """Check how WebServer404Penalty behaves with and without the sticky option."""
+
+    def test_non_sticky(self):
+        """Non-sticky reward drops back to 0 on a step that has no response codes."""
+        reward = WebServer404Penalty("computer", "WebService", sticky=False)
+
+        # no response codes yet, reward is 0
+        codes = []
+        state = {
+            "network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
+        }
+        last_action_response = None
+        assert reward.calculate(state, last_action_response) == 0
+
+        # update codes (by reference), 200 response code is now present
+        codes.append(200)
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # update codes (by reference), to make it empty again, reward goes back to 0
+        codes.pop()
+        assert reward.calculate(state, last_action_response) == 0.0
+
+        # update codes (by reference), 404 response code is now present, reward = -1.0
+        codes.append(404)
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # don't update codes, it still has just a 404, check the reward is -1.0 again
+        assert reward.calculate(state, last_action_response) == -1.0
+
+    def test_sticky(self):
+        """Sticky reward keeps its previous value on a step that has no response codes."""
+        reward = WebServer404Penalty("computer", "WebService", sticky=True)
+
+        # no response codes yet, reward is 0
+        codes = []
+        state = {
+            "network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
+        }
+        last_action_response = None
+        assert reward.calculate(state, last_action_response) == 0
+
+        # update codes (by reference), 200 response code is now present
+        codes.append(200)
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # update codes (by reference), to make it empty again, reward remains at 1.0 because it's sticky
+        codes.pop()
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # update codes (by reference), 404 response code is now present, reward = -1.0
+        codes.append(404)
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # don't update codes, it still has just a 404, check the reward is -1.0 again
+        assert reward.calculate(state, last_action_response) == -1.0
+
+
+class TestWebpageUnavailabilitySticky:
+    """Check how WebpageUnavailablePenalty behaves with and without the sticky option."""
+
+    def test_non_sticky(self):
+        """Non-sticky reward drops back to 0 on a step where the agent makes no browser request."""
+        reward = WebpageUnavailablePenalty("computer", sticky=False)
+
+        # no web request made yet, reward is 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        browser_history = []
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0
+
+        # agent did a successful fetch
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="success", data={})
+        browser_history.append({"outcome": 200})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # agent did nothing, because reward is not sticky, it goes back to 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        browser_history = []
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0.0
+
+        # agent fails to fetch, get a -1.0 reward
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="failure", data={})
+        browser_history.append({"outcome": 404})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # agent fails again to fetch, get a -1.0 reward again
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="failure", data={})
+        browser_history.append({"outcome": 404})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+    def test_sticky(self):
+        """Sticky reward keeps its previous value on a step where the agent makes no browser request."""
+        reward = WebpageUnavailablePenalty("computer", sticky=True)
+
+        # no web request made yet, reward is 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        browser_history = []
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0
+
+        # agent did a successful fetch
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="success", data={})
+        browser_history.append({"outcome": 200})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # agent did nothing, because reward is sticky, it stays at 1.0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # agent fails to fetch, get a -1.0 reward
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="failure", data={})
+        browser_history.append({"outcome": 404})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # agent fails again to fetch, get a -1.0 reward again
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
+        response = RequestResponse(status="failure", data={})
+        browser_history.append({"outcome": 404})
+        state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+
+class TestGreenAdminDatabaseUnreachableSticky:
+    """Check how GreenAdminDatabaseUnreachablePenalty behaves with and without the sticky option."""
+
+    def test_non_sticky(self):
+        """Non-sticky reward drops back to 0 on a step where the agent makes no database request."""
+        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
+
+        # no database request made yet, reward is 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0
+
+        # agent made a successful database request
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # agent did nothing, because reward is not sticky, it goes back to 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0.0
+
+        # agent's database request fails, get a -1.0 reward
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="failure", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # agent's database request fails again, get a -1.0 reward again
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="failure", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+    def test_sticky(self):
+        """Sticky reward keeps its previous value on a step where the agent makes no database request."""
+        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
+
+        # no database request made yet, reward is 0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 0
+
+        # agent made a successful database request
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # THE IMPORTANT BIT
+        # agent did nothing, because reward is sticky, it stays at 1.0
+        action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
+        response = RequestResponse(status="success", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == 1.0
+
+        # agent's database request fails, get a -1.0 reward
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="failure", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0
+
+        # agent's database request fails again, get a -1.0 reward again
+        action = "NODE_APPLICATION_EXECUTE"
+        params = {"node_id": 0, "application_id": 0}
+        request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
+        response = RequestResponse(status="failure", data={})
+        state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
+        last_action_response = AgentHistoryItem(
+            timestep=0, action=action, parameters=params, request=request, response=response
+        )
+        assert reward.calculate(state, last_action_response) == -1.0