Add shared reward

This commit is contained in:
Marek Wolan
2024-03-11 22:53:39 +00:00
parent 7599655879
commit c3f1cfb33d
5 changed files with 242 additions and 38 deletions

View File

@@ -73,7 +73,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
@@ -116,7 +123,14 @@ agents:
reward_function:
reward_components:
- type: DUMMY
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
@@ -696,22 +710,15 @@ agents:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_1
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
agent_name: client_2_green_user
agent_settings:

View File

@@ -26,7 +26,9 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
from typing_extensions import Never
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -214,18 +216,29 @@ class WebpageUnavailablePenalty(AbstractReward):
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the browser application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
_LOGGER.info(
@@ -265,20 +278,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
"""
self._node = node_hostname
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(
self, state: Dict, last_action_response: "AgentActionHistoryItem"
) -> float: # todo maybe make last_action_response optional?
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
"""
Calculate the reward based on current simulation state.
Calculate the reward based on current simulation state, and the recent agent action.
:param state: The current state of the simulation.
:type state: Dict
When the green agent requests to execute the database client application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
"""
if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]:
pass # TODO
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
@@ -301,6 +322,52 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return cls(node_hostname=node_hostname)
class SharedReward(AbstractReward):
    """Adds another agent's reward to the overall reward."""

    def __init__(self, agent_name: Optional[str] = None) -> None:
        """
        Initialise the shared reward.

        The agent_name is a placeholder value. It starts off as None, but it must be set before this reward can
        work correctly.

        :param agent_name: The name of the agent whose reward is an input.
        :type agent_name: Optional[str]
        """
        self.agent_name = agent_name
        """Agent whose reward to track."""

        def default_callback() -> Never:
            """
            Default callback to prevent calling this reward until it's properly initialised.

            SharedReward should not be used until the game layer replaces self.callback with a reference to the
            function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
            an error.
            """
            raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")

        self.callback: Callable[[], float] = default_callback
        """Method that retrieves the tracked agent's current reward."""

    def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
        """
        Simply access the other agent's reward and return it.

        :param state: The current state of the simulation (not used by this component).
        :type state: Dict
        :param last_action_response: The agent's most recent action record (not used by this component).
        :return: The tracked agent's current reward, fetched via the callback installed by the game layer.
        :rtype: float
        """
        # Removed leftover debug print that wrote to stdout on every reward calculation.
        return self.callback()

    @classmethod
    def from_config(cls, config: Dict) -> "SharedReward":
        """
        Build the SharedReward object from config.

        :param config: Configuration dictionary. May contain an "agent_name" key.
        :type config: Dict
        """
        agent_name = config.get("agent_name")
        return cls(agent_name=agent_name)
class RewardFunction:
"""Manages the reward function for the agent."""
@@ -310,6 +377,7 @@ class RewardFunction:
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
}
"""List of reward class identifiers."""

View File

@@ -9,8 +9,9 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.rewards import RewardFunction, SharedReward
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -110,6 +111,9 @@ class PrimaiteGame:
self.save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
self._reward_calculation_order: List[str] = [name for name in self.agents]
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -148,10 +152,11 @@ class PrimaiteGame:
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent_name, agent in self.agents.items():
for agent_name in self._reward_calculation_order:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.update_observation(state=state)
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> None:
@@ -443,7 +448,51 @@ class PrimaiteGame:
raise ValueError(msg)
game.agents[agent_cfg["ref"]] = new_agent
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
return game
def setup_reward_sharing(self):
    """Do necessary setup to enable reward sharing between agents.

    This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1
    depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop.

    Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is
    sharing their reward. This callback is provided by this setup method.

    Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards
    that rely on the value of another reward are evaluated later.

    :raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
    """
    # Construct dependency graph in the reward sharing between agents.
    # An edge name -> other means `name`'s reward depends on `other`'s reward.
    graph = {}
    for name, agent in self.agents.items():
        graph[name] = set()
        for comp, weight in agent.reward_function.reward_components:
            if isinstance(comp, SharedReward):
                comp: SharedReward
                graph[name].add(comp.agent_name)
                # While constructing the graph, we might as well set up the reward sharing itself.
                # Bind the tracked agent's name as a default argument: a bare `lambda:` that read
                # `comp.agent_name` would be late-bound, so after the loop every callback would
                # track the *last* component seen. The current_reward lookup itself stays lazy,
                # fetching the most recent value at call time.
                comp.callback = lambda tracked=comp.agent_name: self.agents[tracked].reward_function.current_reward

    # Make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing.
    if graph_has_cycle(graph):
        # Pass a single string: the previous tuple argument produced a garbled message.
        raise RuntimeError(
            "Detected cycle in agent reward sharing. Check the agent reward function "
            "configuration: reward sharing can only go one way."
        )

    # Sort the agents so the rewards that depend on other rewards are always evaluated later.
    # topological_sort places a node before the nodes it points at (A -> B puts A first), i.e.
    # dependents first — so reverse it to evaluate providers before the agents that read them.
    self._reward_calculation_order = topological_sort(graph)[::-1]

View File

@@ -1,4 +1,5 @@
from random import random
from typing import Any, Iterable, Mapping
def simulate_trial(p_of_success: float) -> bool:
@@ -14,3 +15,81 @@ def simulate_trial(p_of_success: float) -> bool:
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
    """Detect cycles in a directed graph.

    Provide the graph as a dictionary that describes which nodes are linked. For example:
    {0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0
    {'a': ('b','c'), c:('b')} here there is no cycle

    :param graph: a mapping from node to a set of nodes to which it is connected.
    :type graph: Mapping[Any, Iterable[Any]]
    :return: Whether the graph has any cycles
    :rtype: bool
    """
    finished = set()  # nodes whose entire reachable subtree was explored without finding a cycle
    on_path = set()  # nodes on the current DFS path; meeting one of these again means a cycle

    def _visit(current: Any) -> bool:
        """Return True if a cycle is reachable from ``current``."""
        if current in on_path:
            return True  # back-edge to an ancestor: cycle
        if current in finished:
            return False  # subtree already cleared, nothing new to find
        finished.add(current)
        on_path.add(current)
        cycle_found = any(_visit(successor) for successor in graph.get(current, []))
        if not cycle_found:
            on_path.remove(current)
        return cycle_found

    # Launch a DFS from every node so disconnected components are covered too.
    return any(_visit(node) for node in graph)
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
    """
    Perform topological sorting on a directed graph.

    This guarantees that if there's a directed edge from node A to node B, then A appears before B.

    :param graph: A dictionary representing the directed graph, where keys are node identifiers
        and values are lists of outgoing edges from each node.
    :type graph: dict[int, list[Any]]
    :return: A topologically sorted list of node identifiers.
    :rtype: list[Any]
    """
    seen: set[Any] = set()
    postorder: list[Any] = []

    def _explore(current: Any) -> None:
        """
        Record ``current`` in postorder once all of its successors have been recorded.

        :param current: The node to explore.
        :type current: Any
        """
        if current in seen:
            return
        seen.add(current)
        for successor in graph.get(current, []):
            _explore(successor)
        postorder.append(current)

    # Run DFS from every node so disconnected components are included.
    for root in graph:
        _explore(root)

    # A node lands in postorder only after everything it points at, so the
    # reversed postorder is a valid topological ordering.
    return list(reversed(postorder))

View File

@@ -77,16 +77,17 @@ class PrimaiteIO:
:type episode: int
"""
data = {}
longest_history = max([len(hist) for hist in agent_actions])
longest_history = max([len(hist) for hist in agent_actions.values()])
for i in range(longest_history):
data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}}
data[i] = {"timestep": i, "episode": episode}
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
path = self.generate_agent_actions_save_path(episode=episode)
path.parent.mkdir(exist_ok=True, parents=True)
path.touch()
_LOGGER.info(f"Saving agent action log to {path}")
with open(path, "w") as file:
json.dump(data, fp=file, indent=1)
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":