#2736 - Add sticky reward tests and fix sticky reward behaviour
This commit is contained in:
@@ -180,13 +180,17 @@ class WebServer404Penalty(AbstractReward):
|
||||
return 0.0
|
||||
|
||||
codes = web_service_state.get("response_codes_this_timestep")
|
||||
if codes or not self.sticky: # skip calculating if sticky and no new codes. Instead, reuse last step's value.
|
||||
if codes:
|
||||
|
||||
def status2rew(status: int) -> int:
|
||||
"""Map status codes to reward values."""
|
||||
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
|
||||
|
||||
self.reward = sum(map(status2rew, codes)) / len(codes) # convert from HTTP codes to rewards and average
|
||||
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
|
||||
self.reward = 0
|
||||
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
|
||||
pass
|
||||
|
||||
return self.reward
|
||||
|
||||
@@ -319,13 +323,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
"""
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
if db_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
self.reward = 0.0
|
||||
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
@@ -335,11 +332,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
"execute",
|
||||
]
|
||||
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
return self.reward
|
||||
|
||||
if request_attempted: # if agent makes request, always recalculate fresh value
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
elif not self.sticky: # if no new request and not sticky, set reward to 0
|
||||
self.reward = 0.0
|
||||
else: # if no new request and sticky, reuse reward value from last step
|
||||
pass
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
|
||||
299
tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
Normal file
299
tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import (
|
||||
GreenAdminDatabaseUnreachablePenalty,
|
||||
WebpageUnavailablePenalty,
|
||||
WebServer404Penalty,
|
||||
)
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
|
||||
class TestWebServer404PenaltySticky:
    """Exercise ``WebServer404Penalty`` with the ``sticky`` flag off and on."""

    @staticmethod
    def _state_for(codes):
        """Build the nested state dict around *codes* (the list is shared, not copied)."""
        web_service = {"response_codes_this_timestep": codes}
        return {"network": {"nodes": {"computer": {"services": {"WebService": web_service}}}}}

    def test_non_sticky(self):
        """Without stickiness the reward returns to 0 once the response codes are gone."""
        penalty = WebServer404Penalty("computer", "WebService", sticky=False)
        response_codes = []
        sim_state = self._state_for(response_codes)
        previous_action = None

        # nothing has been served yet, so the reward starts at 0
        assert penalty.calculate(sim_state, previous_action) == 0

        # the state holds a reference to the list, so this adds a 200 in place
        response_codes.append(200)
        assert penalty.calculate(sim_state, previous_action) == 1.0

        # THE IMPORTANT BIT
        # empty the list again (in place); a non-sticky reward falls back to 0
        response_codes.pop()
        assert penalty.calculate(sim_state, previous_action) == 0.0

        # a 404 shows up (in place), reward becomes -1.0
        response_codes.append(404)
        assert penalty.calculate(sim_state, previous_action) == -1.0

        # the list is left alone and still holds only the 404, so -1.0 again
        assert penalty.calculate(sim_state, previous_action) == -1.0

    def test_sticky(self):
        """With stickiness the previous reward is kept while no new codes arrive."""
        penalty = WebServer404Penalty("computer", "WebService", sticky=True)
        response_codes = []
        sim_state = self._state_for(response_codes)
        previous_action = None

        # nothing has been served yet, so the reward starts at 0
        assert penalty.calculate(sim_state, previous_action) == 0

        # the state holds a reference to the list, so this adds a 200 in place
        response_codes.append(200)
        assert penalty.calculate(sim_state, previous_action) == 1.0

        # THE IMPORTANT BIT
        # empty the list again (in place); a sticky reward holds on to 1.0
        response_codes.pop()
        assert penalty.calculate(sim_state, previous_action) == 1.0

        # a 404 shows up (in place), reward becomes -1.0
        response_codes.append(404)
        assert penalty.calculate(sim_state, previous_action) == -1.0

        # the list is left alone and still holds only the 404, so -1.0 again
        assert penalty.calculate(sim_state, previous_action) == -1.0
|
||||
|
||||
|
||||
class TestWebpageUnavailabilitySticky:
    """Exercise ``WebpageUnavailablePenalty`` with the ``sticky`` flag off and on."""

    @staticmethod
    def _state_for(history):
        """Build the nested state dict whose WebBrowser history is *history* (shared, not copied)."""
        return {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": history}}}}}}

    @staticmethod
    def _idle():
        """Return the history item for a step on which the agent did nothing."""
        return AgentHistoryItem(
            timestep=0,
            action="DO_NOTHING",
            parameters={},
            request=["DONOTHING"],
            response=RequestResponse(status="success", data={}),
        )

    @staticmethod
    def _fetch(status):
        """Return the history item for a WebBrowser execute request that ended with *status*."""
        return AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 0},
            request=["network", "node", "computer", "application", "WebBrowser", "execute"],
            response=RequestResponse(status=status, data={}),
        )

    def test_non_sticky(self):
        """Without stickiness the reward returns to 0 on a step with no fetch."""
        penalty = WebpageUnavailablePenalty("computer", sticky=False)

        # no fetch has happened yet, reward is 0
        history = []
        assert penalty.calculate(self._state_for(history), self._idle()) == 0

        # agent did a successful fetch
        history.append({"outcome": 200})
        assert penalty.calculate(self._state_for(history), self._fetch("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing (and the history is a brand-new empty list);
        # the reward is not sticky, so it goes back to 0
        history = []
        assert penalty.calculate(self._state_for(history), self._idle()) == 0.0

        # agent fails to fetch, get a -1.0 reward
        history.append({"outcome": 404})
        assert penalty.calculate(self._state_for(history), self._fetch("failure")) == -1.0

        # agent fails again to fetch, get a -1.0 reward again
        history.append({"outcome": 404})
        assert penalty.calculate(self._state_for(history), self._fetch("failure")) == -1.0

    def test_sticky(self):
        """With stickiness the previous reward is kept on a step with no fetch."""
        penalty = WebpageUnavailablePenalty("computer", sticky=True)

        # no fetch has happened yet, reward is 0
        history = []
        assert penalty.calculate(self._state_for(history), self._idle()) == 0

        # agent did a successful fetch
        history.append({"outcome": 200})
        assert penalty.calculate(self._state_for(history), self._fetch("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing (history untouched); the reward is sticky, so it stays at 1.0
        assert penalty.calculate(self._state_for(history), self._idle()) == 1.0

        # agent fails to fetch, get a -1.0 reward
        history.append({"outcome": 404})
        assert penalty.calculate(self._state_for(history), self._fetch("failure")) == -1.0

        # agent fails again to fetch, get a -1.0 reward again
        history.append({"outcome": 404})
        assert penalty.calculate(self._state_for(history), self._fetch("failure")) == -1.0
|
||||
|
||||
|
||||
class TestGreenAdminDatabaseUnreachableSticky:
    """Exercise ``GreenAdminDatabaseUnreachablePenalty`` with the ``sticky`` flag off and on.

    Every step passes the same minimal state (an empty ``DatabaseClient`` entry);
    only the agent's last action/response differs between steps.
    """

    @staticmethod
    def _state():
        """Return a fresh minimal state dict containing an empty DatabaseClient application."""
        return {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}

    @staticmethod
    def _idle():
        """Return the history item for a step on which the agent did nothing."""
        return AgentHistoryItem(
            timestep=0,
            action="DO_NOTHING",
            parameters={},
            request=["DONOTHING"],
            response=RequestResponse(status="success", data={}),
        )

    @staticmethod
    def _db_request(status):
        """Return the history item for a DatabaseClient execute request that ended with *status*."""
        return AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 0},
            request=["network", "node", "computer", "application", "DatabaseClient", "execute"],
            response=RequestResponse(status=status, data={}),
        )

    def test_non_sticky(self):
        """Without stickiness the reward returns to 0 on a step with no database request."""
        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)

        # no database request attempted yet, reward is 0
        assert reward.calculate(self._state(), self._idle()) == 0

        # agent made a successful database request
        assert reward.calculate(self._state(), self._db_request("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing, because reward is not sticky, it goes back to 0
        assert reward.calculate(self._state(), self._idle()) == 0.0

        # agent's database request fails, get a -1.0 reward
        assert reward.calculate(self._state(), self._db_request("failure")) == -1.0

        # agent's database request fails again, get a -1.0 reward again
        assert reward.calculate(self._state(), self._db_request("failure")) == -1.0

    def test_sticky(self):
        """With stickiness the previous reward is kept on a step with no database request."""
        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)

        # no database request attempted yet, reward is 0
        assert reward.calculate(self._state(), self._idle()) == 0

        # agent made a successful database request
        assert reward.calculate(self._state(), self._db_request("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing, because reward is sticky, it stays at 1.0
        assert reward.calculate(self._state(), self._idle()) == 1.0

        # agent's database request fails, get a -1.0 reward
        assert reward.calculate(self._state(), self._db_request("failure")) == -1.0

        # agent's database request fails again, get a -1.0 reward again
        assert reward.calculate(self._state(), self._db_request("failure")) == -1.0
|
||||
Reference in New Issue
Block a user