From 24fdb8dc17f2a7608e7876a58afa054993e30851 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 12 Mar 2024 11:40:26 +0000 Subject: [PATCH] Fix minor reward sharing bugs --- src/primaite/game/agent/rewards.py | 8 ++--- src/primaite/game/game.py | 5 +-- src/primaite/game/science.py | 3 +- .../Data-Manipulation-E2E-Demonstration.ipynb | 32 +++++++++++-------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 3d61c0b4..a2ffd875 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -230,7 +230,7 @@ class WebpageUnavailablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last webpage had a 200 status code, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]: self._last_request_failed = last_action_response.response.status != "success" # if agent couldn't even get as far as sending the request (because for example the node was off), then @@ -338,7 +338,7 @@ class SharedReward(AbstractReward): self.agent_name = agent_name """Agent whose reward to track.""" - def default_callback() -> Never: + def default_callback(agent_name: str) -> Never: """ Default callback to prevent calling this reward until it's properly initialised. @@ -348,12 +348,12 @@ class SharedReward(AbstractReward): """ raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - self.callback: Callable[[], float] = default_callback + self.callback: Callable[[str], float] = default_callback """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Simply access the other agent's reward and return it.""" - return self.callback() + return self.callback(self.agent_name) @classmethod def from_config(cls, config: Dict) -> "SharedReward": diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index e766bcd3..ac23610c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -480,10 +480,7 @@ class PrimaiteGame: graph[name].add(comp.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. - comp.callback = lambda: self.agents[comp.agent_name].reward_function.current_reward - # TODO: make sure this lambda is working like I think it does -> it goes to the agent and fetches - # the most recent value of current_reward, NOT just simply caching the reward value at the time this - # callback method is defined. + comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward # make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing. if graph_has_cycle(graph): diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index 801ef269..908b326f 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -91,5 +91,4 @@ def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]: for node in graph: dfs(node) - # Reverse the stack and return it. - return stack[::-1] + return stack diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index b2522c2b..946202b6 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -450,7 +450,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now the reward is -1, let's have a look at blue agent's observation." + "Now the reward is -0.8, let's have a look at blue agent's observation." ] }, { @@ -510,9 +510,9 @@ "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user']['action']}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user']['action']}\" )\n", + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n", "print(f\"Blue reward:{reward}\" )" ] }, @@ -533,9 +533,9 @@ "metadata": {}, "outputs": [], "source": [ - "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker']['action']}\" )\n", + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", "print(f\"Blue reward:{reward:.2f}\" )" @@ -557,17 +557,19 @@ "outputs": [], "source": [ "env.step(13) # Patch the database\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(50) # Block client 1\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(51) # Block client 2\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", - "for step in range(30):\n", + "while abs(reward - 0.8) > 1e-5:\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", + " if env.game.step_counter > 10000:\n", + " break # make sure there's no infinite loop if something went wrong" ] }, { @@ -617,17 +619,19 @@ " if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", " # client 1 has NMNEs, let's block it\n", " obs, reward, terminated, truncated, info = env.step(50) # block client 1\n", + " print(\"blocking client 1\")\n", " break\n", " elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", " # client 2 has NMNEs, so let's block it\n", " obs, reward, terminated, truncated, info = env.step(51) # block client 2\n", + " print(\"blocking client 2\")\n", " break\n", " if tries>100:\n", " print(\"Error: NMNE never increased\")\n", " break\n", "\n", "env.step(13) # Patch the database\n", - "..." + "print()\n" ] }, { @@ -646,14 +650,14 @@ "\n", "for step in range(40):\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker']['action']}, Blue reward:{reward:.2f}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode." + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)" ] }, {