From f95ba8cbbcd92779a66684c6bd2dbbf51052e5d2 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 11:01:35 +0100 Subject: [PATCH] #2913: Fix remaining pydantic errors. --- src/primaite/game/agent/rewards.py | 141 ++++++------------ .../_game/_agent/test_sticky_rewards.py | 12 +- 2 files changed, 48 insertions(+), 105 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 9777441b..1f870e83 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -28,7 +28,7 @@ the structure: ``` """ from abc import ABC, abstractmethod -from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union from pydantic import BaseModel from typing_extensions import Never @@ -118,6 +118,12 @@ class DummyReward(AbstractReward, identifier="DummyReward"): class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + node_hostname: str + folder_name: str + file_name: str + location_in_state: List[str] = [""] + reward: float = 0.0 + class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" @@ -125,27 +131,6 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): folder_name: str file_name: str - def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: - """Initialise the reward component. - - :param node_hostname: Hostname of the node which contains the database file. - :type node_hostname: str - :param folder_name: folder which contains the database file. - :type folder_name: str - :param file_name: name of the database file. - :type file_name: str - """ - self.location_in_state = [ - "network", - "nodes", - node_hostname, - "file_system", - "folders", - folder_name, - "files", - file_name, - ] - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -156,6 +141,17 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "file_system", + "folders", + self.folder_name, + "files", + self.file_name, + ] + database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: _LOGGER.debug( @@ -195,6 +191,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" + node_hostname: str + service_name: str + sticky: bool = True + location_in_state: List[str] = [""] + reward: float = 0.0 + class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" @@ -202,22 +204,6 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): service_name: str sticky: bool = True - def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: - """Initialise the reward component. - - :param node_hostname: Hostname of the node which contains the web server service. - :type node_hostname: str - :param service_name: Name of the web server service. - :type service_name: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -228,6 +214,13 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "services", + self.service_name, + ] web_service_state = access_from_nested_dict(state, self.location_in_state) # if webserver is no longer installed on the node, return 0 @@ -274,6 +267,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" + node_hostname: str = "" sticky: bool = True reward: float = 0.0 @@ -287,22 +281,6 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe sticky: bool = True reward: float = 0.0 - # def __init__(self, node_hostname: str, sticky: bool = True) -> None: - # """ - # Initialise the reward component. - - # :param node_hostname: Hostname of the node which has the web browser. - # :type node_hostname: str - # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - # the reward if there were any responses this timestep. - # :type sticky: bool - # """ - # self._node: str = node_hostname - # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - # self.sticky: bool = sticky - # self.reward: float = 0.0 - # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -317,7 +295,13 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe :return: Reward value :rtype: float """ - self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"] + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "applications", + "WebBrowser", + ] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -371,6 +355,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" + node_hostname: str = "" _node: str = node_hostname sticky: bool = True @@ -382,22 +367,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi node_hostname: str sticky: bool = True - # def __init__(self, node_hostname: str, sticky: bool = True) -> None: - # """ - # Initialise the reward component. - - # :param node_hostname: Hostname of the node where the database client sits. - # :type node_hostname: str - # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - # the reward if there were any responses this timestep. - # :type sticky: bool - # """ - # self._node: str = node_hostname - # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - # self.sticky: bool = sticky - # self.reward: float = 0.0 - # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -449,6 +418,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" + agent_name: str class ConfigSchema(AbstractReward.ConfigSchema): @@ -456,19 +426,6 @@ class SharedReward(AbstractReward, identifier="SharedReward"): agent_name: str - # def __init__(self, agent_name: Optional[str] = None) -> None: - # """ - # Initialise the shared reward. - - # The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - # correctly. - - # :param agent_name: The name whose reward is an input - # :type agent_name: Optional[str] - # """ - # # self.agent_name = agent_name - # """Agent whose reward to track.""" - def default_callback(agent_name: str) -> Never: """ Default callback to prevent calling this reward until it's properly initialised. @@ -508,6 +465,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" + action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -517,21 +475,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 - # def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - # """ - # Initialise the reward. - - # Reward or penalise agents for doing nothing or taking actions. - - # :param action_penalty: Reward to give agents for taking any action except DONOTHING - # :type action_penalty: float - # :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action - # :type do_nothing_penalty: float - # """ - # super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) - # self.action_penalty = action_penalty - # self.do_nothing_penalty = do_nothing_penalty - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 58f0fcc1..2ad1a322 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,7 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=False) + reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=False) # no response codes yet, reward is 0 codes = [] @@ -38,7 +38,7 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=True) + reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=True) # no response codes yet, reward is 0 codes = [] @@ -67,7 +67,7 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=False) + reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=False) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -127,7 +127,7 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=True) + reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=True) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -188,7 +188,7 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=False) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -244,7 +244,7 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=True) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"]