#2913: Update reward classes to work with pydantic.
This commit is contained in:
@@ -274,28 +274,34 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
location_in_state: List[str] = [""]
|
||||
_node: str = node_hostname
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebpageUnavailablePenalty."""
|
||||
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
# """
|
||||
# Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
# :param node_hostname: Hostname of the node which has the web browser.
|
||||
# :type node_hostname: str
|
||||
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
# the reward if there were any responses this timestep.
|
||||
# :type sticky: bool
|
||||
# """
|
||||
# self._node: str = node_hostname
|
||||
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
# self.sticky: bool = sticky
|
||||
# self.reward: float = 0.0
|
||||
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -311,6 +317,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"]
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -364,6 +371,10 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
node_hostname: str = ""
|
||||
_node: str = node_hostname
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
|
||||
@@ -371,21 +382,21 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
node_hostname: str
|
||||
sticky: bool = True
|
||||
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
# """
|
||||
# Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
# :param node_hostname: Hostname of the node where the database client sits.
|
||||
# :type node_hostname: str
|
||||
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
# the reward if there were any responses this timestep.
|
||||
# :type sticky: bool
|
||||
# """
|
||||
# self._node: str = node_hostname
|
||||
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
# self.sticky: bool = sticky
|
||||
# self.reward: float = 0.0
|
||||
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -438,37 +449,38 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
|
||||
class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
agent_name: str
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for SharedReward."""
|
||||
|
||||
agent_name: str
|
||||
|
||||
def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
# def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
# """
|
||||
# Initialise the shared reward.
|
||||
|
||||
# The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
# correctly.
|
||||
|
||||
# :param agent_name: The name whose reward is an input
|
||||
# :type agent_name: Optional[str]
|
||||
# """
|
||||
# # self.agent_name = agent_name
|
||||
# """Agent whose reward to track."""
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Initialise the shared reward.
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
correctly.
|
||||
|
||||
:param agent_name: The name whose reward is an input
|
||||
:type agent_name: Optional[str]
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
"""Agent whose reward to track."""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
|
||||
SharedReward should not be used until the game layer replaces self.callback with a reference to the
|
||||
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
|
||||
an error.
|
||||
"""
|
||||
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
|
||||
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it.
|
||||
|
||||
@@ -74,7 +74,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
|
||||
)
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1")
|
||||
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
|
||||
Reference in New Issue
Block a user