diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 029597a0..05bca033 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -211,7 +211,7 @@ class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average - elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0 self.reward = 0.0 else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass @@ -278,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_ elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self.node_hostname} was not reported in the simulation state. Returning 0.0", + f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 2ad1a322..c758291f 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=False) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=False, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=True) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=True, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -67,7 +77,8 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=False) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -127,7 +138,8 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=True) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -188,7 +200,11 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=False) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=False, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -214,7 +230,6 @@ class TestGreenAdminDatabaseUnreachableSticky: # agent did nothing, because reward is not sticky, it goes back to 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] response = RequestResponse(status="success", data={}) - browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response @@ -244,7 +259,11 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=True) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=True, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"]