Add shared reward
This commit is contained in:
@@ -73,7 +73,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
@@ -116,7 +123,14 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
|
||||
|
||||
@@ -696,22 +710,15 @@ agents:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
agent_name: client_2_green_user
|
||||
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -26,7 +26,9 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -214,18 +216,29 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the browser application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request did actually go through, then check if the webpage also loaded
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
_LOGGER.info(
|
||||
@@ -265,20 +278,28 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
"""
|
||||
self._node = node_hostname
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(
|
||||
self, state: Dict, last_action_response: "AgentActionHistoryItem"
|
||||
) -> float: # todo maybe make last_action_response optional?
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state.
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:type state: Dict
|
||||
When the green agent requests to execute the database client application, and that request fails, this reward
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last successful
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]:
|
||||
pass # TODO
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
@@ -301,6 +322,52 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
return cls(node_hostname=node_hostname)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
    """Reward component whose value is another agent's current reward."""

    def __init__(self, agent_name: Optional[str] = None) -> None:
        """
        Initialise the shared reward.

        The ``callback`` attribute starts off as a placeholder that raises an error. It must be replaced with a
        function that returns the tracked agent's reward before this component can be calculated.

        :param agent_name: Name of the agent whose reward is used as the value of this component.
        :type agent_name: Optional[str]
        """
        self.agent_name = agent_name
        """Agent whose reward to track."""

        def default_callback() -> Never:
            """
            Default callback to prevent calling this reward until it's properly initialised.

            SharedReward should not be used until self.callback has been replaced with a reference to a function
            that retrieves the desired agent's reward. Until then, any attempt to calculate raises an error.
            """
            raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")

        self.callback: Callable[[], float] = default_callback
        """Zero-argument function that returns the tracked agent's current reward."""

    def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
        """
        Return the tracked agent's reward, as provided by ``self.callback``.

        :param state: The current state of the simulation. Not used by this component.
        :type state: Dict
        :param last_action_response: The most recent action of the owning agent. Not used by this component.
        :type last_action_response: AgentActionHistoryItem
        :raises RuntimeError: If ``self.callback`` is still the default placeholder.
        :return: The value returned by the callback.
        :rtype: float
        """
        # Invoke the callback exactly once; no debug output on the reward path.
        return self.callback()

    @classmethod
    def from_config(cls, config: Dict) -> "SharedReward":
        """
        Build the SharedReward object from config.

        :param config: Configuration dictionary. The optional ``agent_name`` key names the agent to track.
        :type config: Dict
        :return: A SharedReward whose callback still has to be set.
        :rtype: SharedReward
        """
        agent_name = config.get("agent_name")
        return cls(agent_name=agent_name)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
@@ -310,6 +377,7 @@ class RewardFunction:
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
|
||||
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
|
||||
"SHARED_REWARD": SharedReward,
|
||||
}
|
||||
"""List of reward class identifiers."""
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.rewards import RewardFunction, SharedReward
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -110,6 +111,9 @@ class PrimaiteGame:
|
||||
self.save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
self._reward_calculation_order: List[str] = [name for name in self.agents]
|
||||
"""Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards."""
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -148,10 +152,11 @@ class PrimaiteGame:
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for agent_name, agent in self.agents.items():
|
||||
for agent_name in self._reward_calculation_order:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.update_observation(state=state)
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
def apply_agent_actions(self) -> None:
|
||||
@@ -443,7 +448,51 @@ class PrimaiteGame:
|
||||
raise ValueError(msg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
|
||||
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
|
||||
game.setup_reward_sharing()
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
|
||||
return game
|
||||
|
||||
def setup_reward_sharing(self):
    """Do necessary setup to enable reward sharing between agents.

    This method does three things:

    1. Gives every SharedReward component a callback that fetches the *current* reward of the agent it tracks.
    2. Checks that reward sharing has no cycles. A cycle would be for example if agent_1 depends on agent_2 and
       agent_2 depends on agent_1.
    3. Stores in ``self._reward_calculation_order`` an agent order in which every agent comes after all the
       agents whose reward it depends on.

    :raises ValueError: If a SharedReward refers to an agent name that is not in ``self.agents``.
    :raises RuntimeError: If the reward sharing is specified with a cyclic dependency.
    """
    # Dependency graph: agent name -> names of the agents whose reward it consumes.
    graph = {}
    for name, agent in self.agents.items():
        graph[name] = set()
        for comp, _weight in agent.reward_function.reward_components:
            if not isinstance(comp, SharedReward):
                continue

            source_name = comp.agent_name
            if source_name not in self.agents:
                raise ValueError(
                    f"Agent '{name}' has a SharedReward that refers to unknown agent '{source_name}'."
                )
            graph[name].add(source_name)

            # Bind the agent name as a default argument. A plain closure over the loop variable would be
            # late-binding: every callback would look up whichever component the loop visited last. The
            # reward itself is still read at call time, so the callback always returns the latest value.
            comp.callback = lambda agent_name=source_name: self.agents[agent_name].reward_function.current_reward

    # make sure the graph is acyclic. Otherwise reward sharing can never be resolved.
    if graph_has_cycle(graph):
        raise RuntimeError(
            "Detected cycle in agent reward sharing. Check the agent reward function "
            "configuration: reward sharing can only go one way."
        )

    # Edges point from a dependent agent to the agent it depends on, and topological_sort places the source of
    # an edge before its target. Reverse the result so that dependencies are evaluated first.
    self._reward_calculation_order = list(topological_sort(graph))[::-1]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from random import random
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
|
||||
def simulate_trial(p_of_success: float) -> bool:
|
||||
@@ -14,3 +15,81 @@ def simulate_trial(p_of_success: float) -> bool:
|
||||
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
|
||||
"""
|
||||
return random() < p_of_success
|
||||
|
||||
|
||||
def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool:
    """Detect cycles in a directed graph.

    Provide the graph as a mapping that describes which nodes are linked. For example:
    {0: {1, 2}, 1: {2, 3}, 3: {0}} here there's a cycle 0 -> 1 -> 3 -> 0
    {'a': ('b', 'c'), 'c': ('b',)} here there is no cycle

    Nodes that only appear as neighbours (and not as keys) are treated as having no outgoing edges.

    :param graph: a mapping from node to the nodes to which it is connected.
    :type graph: Mapping[Any, Iterable[Any]]
    :return: Whether the graph has any cycles
    :rtype: bool
    """
    visited = set()  # every node the search has entered so far
    on_path = set()  # nodes on the path from the current root to the node being explored

    for root in graph:
        if root in visited:
            continue
        visited.add(root)
        on_path.add(root)

        # Explicit stack of (node, iterator over its neighbours) instead of recursion, so that long chains of
        # nodes cannot exceed the interpreter's recursion limit.
        trail = [(root, iter(graph.get(root, ())))]
        while trail:
            current, pending = trail[-1]
            for neighbour in pending:
                if neighbour in on_path:
                    return True  # back edge to a node on the current path: cycle detected
                if neighbour not in visited:
                    visited.add(neighbour)
                    on_path.add(neighbour)
                    trail.append((neighbour, iter(graph.get(neighbour, ()))))
                    break  # descend into this neighbour before looking at its siblings
            else:
                # all neighbours of `current` have been explored
                on_path.remove(current)
                trail.pop()

    return False  # No cycles found
|
||||
|
||||
|
||||
def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]:
    """
    Perform topological sorting on a directed graph.

    This guarantees that if there's a directed edge from node A to node B, then A appears before B. The graph is
    not checked for cycles; if it has any, the returned order cannot satisfy that guarantee for every edge.

    Nodes that only appear as neighbours (and not as keys) are included in the result and are treated as having
    no outgoing edges.

    :param graph: A mapping representing the directed graph, where keys are node identifiers
        and values are the nodes reached by the outgoing edges of each node.
    :type graph: Mapping[Any, Iterable[Any]]

    :return: A topologically sorted list of node identifiers.
    :rtype: list[Any]
    """
    visited = set()
    finished = []  # nodes in the order their exploration completed (DFS post-order)

    for root in graph:
        if root in visited:
            continue
        visited.add(root)

        # Explicit stack of (node, iterator over its neighbours) instead of recursion, so that long chains of
        # nodes cannot exceed the interpreter's recursion limit.
        trail = [(root, iter(graph.get(root, ())))]
        while trail:
            current, pending = trail[-1]
            for neighbour in pending:
                if neighbour not in visited:
                    visited.add(neighbour)
                    trail.append((neighbour, iter(graph.get(neighbour, ()))))
                    break  # descend into this neighbour before looking at its siblings
            else:
                # every neighbour of `current` is done, so `current` itself is done
                trail.pop()
                finished.append(current)

    # A node finishes only after everything reachable from it, so the reversed post-order is a topological order.
    finished.reverse()
    return finished
|
||||
|
||||
@@ -77,16 +77,17 @@ class PrimaiteIO:
|
||||
:type episode: int
|
||||
"""
|
||||
data = {}
|
||||
longest_history = max([len(hist) for hist in agent_actions])
|
||||
longest_history = max([len(hist) for hist in agent_actions.values()])
|
||||
for i in range(longest_history):
|
||||
data[i] = {"timestep": i, "episode": episode, **{name: acts[i] for name, acts in agent_actions.items()}}
|
||||
data[i] = {"timestep": i, "episode": episode}
|
||||
data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
|
||||
|
||||
path = self.generate_agent_actions_save_path(episode=episode)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
path.touch()
|
||||
_LOGGER.info(f"Saving agent action log to {path}")
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, fp=file, indent=1)
|
||||
json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
|
||||
Reference in New Issue
Block a user