Fix minor reward sharing bugs
This commit is contained in:
@@ -230,7 +230,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
@@ -338,7 +338,7 @@ class SharedReward(AbstractReward):
|
||||
self.agent_name = agent_name
|
||||
"""Agent whose reward to track."""
|
||||
|
||||
def default_callback() -> Never:
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
@@ -348,12 +348,12 @@ class SharedReward(AbstractReward):
|
||||
"""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
self.callback: Callable[[], float] = default_callback
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback()
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "SharedReward":
|
||||
|
||||
@@ -480,10 +480,7 @@ class PrimaiteGame:
|
||||
graph[name].add(comp.agent_name)
|
||||
|
||||
# while constructing the graph, we might as well set up the reward sharing itself.
|
||||
comp.callback = lambda: self.agents[comp.agent_name].reward_function.current_reward
|
||||
# TODO: make sure this lambda is working like I think it does -> it goes to the agent and fetches
|
||||
# the most recent value of current_reward, NOT just simply caching the reward value at the time this
|
||||
# callback method is defined.
|
||||
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward
|
||||
|
||||
# make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
|
||||
if graph_has_cycle(graph):
|
||||
|
||||
@@ -91,5 +91,4 @@ def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
|
||||
for node in graph:
|
||||
dfs(node)
|
||||
|
||||
# Reverse the stack and return it.
|
||||
return stack[::-1]
|
||||
return stack
|
||||
|
||||
@@ -450,7 +450,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now the reward is -1, let's have a look at blue agent's observation."
|
||||
"Now the reward is -0.8. Let's have a look at the blue agent's observation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -510,9 +510,9 @@
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(13) # patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n",
|
||||
"print(f\"Blue reward:{reward}\" )"
|
||||
]
|
||||
},
|
||||
@@ -533,9 +533,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # patch the database\n",
|
||||
"obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
"print(f\"step: {env.game.step_counter}\")\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n",
|
||||
"print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n",
|
||||
"print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n",
|
||||
"print(f\"Blue reward:{reward:.2f}\" )"
|
||||
@@ -557,17 +557,19 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(13) # Patch the database\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(50) # Block client 1\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"env.step(51) # Block client 2\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n",
|
||||
"print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
"\n",
|
||||
"for step in range(30):\n",
|
||||
"while abs(reward - 0.8) > 1e-5:\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n",
|
||||
" if env.game.step_counter > 10000:\n",
|
||||
" break # make sure there's no infinite loop if something went wrong"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -617,17 +619,19 @@
|
||||
" if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 1 has NMNEs, let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(50) # block client 1\n",
|
||||
" print(\"blocking client 1\")\n",
|
||||
" break\n",
|
||||
" elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n",
|
||||
" # client 2 has NMNEs, so let's block it\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(51) # block client 2\n",
|
||||
" print(\"blocking client 2\")\n",
|
||||
" break\n",
|
||||
" if tries>100:\n",
|
||||
" print(\"Error: NMNE never increased\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.step(13) # Patch the database\n",
|
||||
"..."
|
||||
"print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -646,14 +650,14 @@
|
||||
"\n",
|
||||
"for step in range(40):\n",
|
||||
" obs, reward, terminated, truncated, info = env.step(0) # do nothing\n",
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )"
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Reset the environment, you can rerun the other cells to verify that the attack works the same every episode."
|
||||
"Reset the environment. You can rerun the other cells to verify that the attack works the same every episode (except that the red agent will move between `client_1` and `client_2`)."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user