Fix minor reward sharing bugs

This commit is contained in:
Marek Wolan
2024-03-12 11:40:26 +00:00
parent 03ee976a2d
commit 24fdb8dc17
4 changed files with 24 additions and 24 deletions

View File

@@ -230,7 +230,7 @@ class WebpageUnavailablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
@@ -338,7 +338,7 @@ class SharedReward(AbstractReward):
self.agent_name = agent_name
"""Agent whose reward to track."""
def default_callback() -> Never:
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
@@ -348,12 +348,12 @@ class SharedReward(AbstractReward):
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[], float] = default_callback
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback()
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":

View File

@@ -480,10 +480,7 @@ class PrimaiteGame:
graph[name].add(comp.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda: self.agents[comp.agent_name].reward_function.current_reward
# TODO: make sure this lambda is working like I think it does -> it goes to the agent and fetches
# the most recent value of current_reward, NOT just simply caching the reward value at the time this
# callback method is defined.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
if graph_has_cycle(graph):

View File

@@ -91,5 +91,4 @@ def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
for node in graph:
dfs(node)
# Return the stack as-is: DFS pushes a node only after all its successors, so dependencies come first.
return stack[::-1]
return stack

View File

@@ -450,7 +450,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now the reward is -1, let's have a look at blue agent's observation."
"Now the reward is -0.8; let's have a look at the blue agent's observation."
]
},
{
@@ -510,9 +510,9 @@
"source": [
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
"print(f\"Blue reward:{reward}\" )"
]
},
@@ -533,9 +533,9 @@
"metadata": {},
"outputs": [],
"source": [
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
"print(f\"step: {env.game.step_counter}\")\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
"print(f\"Blue reward:{reward:.2f}\" )"
@@ -557,17 +557,19 @@
"outputs": [],
"source": [
"env.step(13) # Patch the database\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(50) # Block client 1\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"env.step(51) # Block client 2\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
"\n",
"for step in range(30):\n",
"while abs(reward - 0.8) > 1e-5:\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
" if env.game.step_counter > 10000:\n",
" break # make sure there's no infinite loop if something went wrong"
]
},
{
@@ -617,17 +619,19 @@
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 1 has NMNEs, let's block it\n",
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
" print(\"blocking client 1\")\n",
" break\n",
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
" # client 2 has NMNEs, so let's block it\n",
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
" print(\"blocking client 2\")\n",
" break\n",
" if tries>100:\n",
" print(\"Error: NMNE never increased\")\n",
" break\n",
"\n",
"env.step(13) # Patch the database\n",
"..."
"print()\n"
]
},
{
@@ -646,14 +650,14 @@
"\n",
"for step in range(40):\n",
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
"Reset the environment; you can rerun the other cells to verify that the attack works the same every episode (except that the red agent will move between `client_1` and `client_2`)."
]
},
{