#2913: Update reward classes to work with pydantic.

This commit is contained in:
Nick Todd
2024-10-21 17:11:11 +01:00
parent bbcbb26f5e
commit 0cf8e20e6d
2 changed files with 62 additions and 50 deletions

View File

@@ -274,28 +274,34 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
node_hostname: str = ""
sticky: bool = True
reward: float = 0.0
location_in_state: List[str] = [""]
_node: str = node_hostname
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
node_hostname: str = ""
sticky: bool = True
reward: float = 0.0
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
# """
# Initialise the reward component.
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
# :param node_hostname: Hostname of the node which has the web browser.
# :type node_hostname: str
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
# the reward if there were any responses this timestep.
# :type sticky: bool
# """
# self._node: str = node_hostname
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
# self.sticky: bool = sticky
# self.reward: float = 0.0
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -311,6 +317,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
:return: Reward value
:rtype: float
"""
self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"]
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE:
@@ -364,6 +371,10 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
"""Penalises the agent when the green db clients fail to connect to the database."""
node_hostname: str = ""
_node: str = node_hostname
sticky: bool = True
reward: float = 0.0
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
@@ -371,21 +382,21 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
node_hostname: str
sticky: bool = True
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
# """
# Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
# :param node_hostname: Hostname of the node where the database client sits.
# :type node_hostname: str
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
# the reward if there were any responses this timestep.
# :type sticky: bool
# """
# self._node: str = node_hostname
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
# self.sticky: bool = sticky
# self.reward: float = 0.0
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -438,37 +449,38 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
class SharedReward(AbstractReward, identifier="SharedReward"):
"""Adds another agent's reward to the overall reward."""
agent_name: str
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for SharedReward."""
agent_name: str
def __init__(self, agent_name: Optional[str] = None) -> None:
# def __init__(self, agent_name: Optional[str] = None) -> None:
# """
# Initialise the shared reward.
# The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
# correctly.
# :param agent_name: The name whose reward is an input
# :type agent_name: Optional[str]
# """
# # self.agent_name = agent_name
# """Agent whose reward to track."""
def default_callback(agent_name: str) -> Never:
"""
Initialise the shared reward.
Default callback to prevent calling this reward until it's properly initialised.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it.

View File

@@ -74,7 +74,7 @@ def test_uc2_rewards(game_and_agent):
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1")
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
response = game.simulation.apply_request(request)