From b849ea6312c21366256fc297297017112dab1d9b Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 4 Nov 2024 17:41:43 +0000 Subject: [PATCH] #2913: Remove from_config() and refactor (WIP). --- src/primaite/game/agent/rewards.py | 167 +++++------------- .../game_layer/test_rewards.py | 11 +- 2 files changed, 48 insertions(+), 130 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 2386bed5..03764e4b 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -46,6 +46,14 @@ WhereType = Optional[Iterable[Union[str, int]]] class AbstractReward(BaseModel): """Base class for reward function components.""" + config: "AbstractReward.ConfigSchema" + + # def __init__(self, schema_name, **kwargs): + # super.__init__(self, **kwargs) + # # Create ConfigSchema class + # self.config_class = type(schema_name, (BaseModel, ABC), **kwargs) + # self.config = self.config_class() + class ConfigSchema(BaseModel, ABC): """Config schema for AbstractReward.""" @@ -56,7 +64,7 @@ class AbstractReward(BaseModel): def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if identifier in cls._registry: - raise ValueError(f"Duplicate node adder {identifier}") + raise ValueError(f"Duplicate reward {identifier}") cls._registry[identifier] = cls @classmethod @@ -70,9 +78,10 @@ class AbstractReward(BaseModel): """ if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") - adder_class = cls._registry[config["type"]] - adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) - return cls + reward_class = cls._registry[config["type"]] + reward_config = reward_class.ConfigSchema(**config) + reward_obj = reward_class(config=reward_config) + return reward_obj @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -103,30 +112,18 @@ class DummyReward(AbstractReward, 
identifier="DummyReward"): """ return 0.0 - @classmethod - def from_config(cls, config: dict) -> "DummyReward": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor. Should be empty. - :type config: dict - :return: The reward component. - :rtype: DummyReward - """ - return cls() - class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - node_hostname: str - folder_name: str - file_name: str + config: "DatabaseFileIntegrity.ConfigSchema" location_in_state: List[str] = [""] reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" + type: str = "DatabaseFileIntegrity" node_hostname: str folder_name: str file_name: str @@ -144,12 +141,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "file_system", "folders", - self.folder_name, + self.config.folder_name, "files", - self.file_name, + self.config.file_name, ] database_file_state = access_from_nested_dict(state, self.location_in_state) @@ -168,38 +165,18 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): else: return 0 - @classmethod - def from_config(cls, config: Dict) -> "DatabaseFileIntegrity": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. 
- :rtype: DatabaseFileIntegrity - """ - node_hostname = config.get("node_hostname") - folder_name = config.get("folder_name") - file_name = config.get("file_name") - if not (node_hostname and folder_name and file_name): - msg = f"{cls.__name__} could not be initialised with parameters {config}" - _LOGGER.error(msg) - raise ValueError(msg) - - return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) - class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" - node_hostname: str - service_name: str - sticky: bool = True + config: "WebServer404Penalty.ConfigSchema" location_in_state: List[str] = [""] reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" + type: str = "WebServer404Penalty" node_hostname: str service_name: str sticky: bool = True @@ -217,9 +194,9 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "services", - self.service_name, + self.config.service_name, ] web_service_state = access_from_nested_dict(state, self.location_in_state) @@ -242,43 +219,20 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): return self.reward - @classmethod - def from_config(cls, config: Dict) -> "WebServer404Penalty": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: WebServer404Penalty - """ - node_hostname = config.get("node_hostname") - service_name = config.get("service_name") - if not (node_hostname and service_name): - msg = ( - f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " - "found in reward config." 
- ) - _LOGGER.warning(msg) - raise ValueError(msg) - sticky = config.get("sticky", True) - - return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) - class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" - node_hostname: str = "" - sticky: bool = True - reward: float = 0.0 - location_in_state: List[str] = [""] + config: "WebpageUnavailablePenalty.ConfigSchema" + reward: float = 0.0 # XXX: Private attribute? + location_in_state: List[str] = [""] # Calculate in __init__()? class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" + type: str = "WebpageUnavailablePenalty" node_hostname: str = "" sticky: bool = True - reward: float = 0.0 def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -297,7 +251,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "applications", "WebBrowser", ] @@ -310,14 +264,14 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe request_attempted = last_action_response.request == [ "network", "node", - self.node_hostname, + self.config.node_hostname, "application", "WebBrowser", "execute", ] # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: + if not request_attempted and self.config.sticky: return self.reward if last_action_response.response.status != "success": @@ -339,29 +293,17 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe return self.reward - @classmethod - def from_config(cls, config: dict) -> AbstractReward: - """ - Build the reward component object from config. - - :param config: Configuration dictionary. 
- :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" - node_hostname: str = "" - sticky: bool = True + config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + type: str = "GreenAdminDatabaseUnreachablePenalty" node_hostname: str sticky: bool = True @@ -383,7 +325,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi request_attempted = last_action_response.request == [ "network", "node", - self.node_hostname, + self.config.node_hostname, "application", "DatabaseClient", "execute", @@ -392,7 +334,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi if request_attempted: # if agent makes request, always recalculate fresh value last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 - elif not self.sticky: # if no new request and not sticky, set reward to 0 + elif not self.config.sticky: # if no new request and not sticky, set reward to 0 last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step @@ -401,27 +343,16 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi return self.reward - @classmethod - def from_config(cls, config: Dict) -> AbstractReward: - """ - Build the reward component object from config. - - :param config: Configuration dictionary. 
- :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" - agent_name: str + config: "SharedReward.ConfigSchema" class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" + type: str = "SharedReward" agent_name: str def default_callback(agent_name: str) -> Never: @@ -447,29 +378,18 @@ class SharedReward(AbstractReward, identifier="SharedReward"): :return: Reward value :rtype: float """ - return self.callback(self.agent_name) - - @classmethod - def from_config(cls, config: Dict) -> "SharedReward": - """ - Build the SharedReward object from config. - - :param config: Configuration dictionary - :type config: Dict - """ - agent_name = config.get("agent_name") - return cls(agent_name=agent_name) + return self.callback(self.config.agent_name) class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" - action_penalty: float = -1.0 - do_nothing_penalty: float = 0.0 + config: "ActionPenalty.ConfigSchema" class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" + type: str = "ActionPenalty" action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -484,16 +404,9 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): :rtype: float """ if last_action_response.action == "DONOTHING": - return self.do_nothing_penalty + return self.config.do_nothing_penalty else: - return self.action_penalty - - @classmethod - def from_config(cls, config: Dict) -> "ActionPenalty": - """Build the ActionPenalty object from config.""" - action_penalty = config.get("action_penalty", -1.0) - do_nothing_penalty = config.get("do_nothing_penalty", 0.0) - return cls(action_penalty=action_penalty, 
do_nothing_penalty=do_nothing_penalty) + return self.config.action_penalty class RewardFunction: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index bf707feb..d4236d1b 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -23,7 +23,9 @@ def test_WebpageUnavailablePenalty(game_and_agent): # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent - comp = WebpageUnavailablePenalty(node_hostname="client_1") + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = WebpageUnavailablePenalty(config=schema) + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() @@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1") + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = GreenAdminDatabaseUnreachablePenalty(config=schema) request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) @@ -147,7 +150,9 @@ def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + # Pydantic models only accept keyword arguments, so the schema is passed as config= + Penalty = ActionPenalty(config=schema) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate(