From 70c1857bbc54414f48e42f065f791a8bec45caae Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 11 Oct 2023 15:49:41 +0100 Subject: [PATCH] Implement rewards for UC2 (draft) --- example_config.yaml | 13 ++- src/primaite/game/agent/observations.py | 30 +----- src/primaite/game/agent/rewards.py | 123 +++++++++++++++++++++--- src/primaite/game/agent/utils.py | 29 ++++++ 4 files changed, 152 insertions(+), 43 deletions(-) create mode 100644 src/primaite/game/agent/utils.py diff --git a/example_config.yaml b/example_config.yaml index a35c82e0..b700da5c 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -445,7 +445,18 @@ game_config: reward_function: reward_components: - - type: DUMMY + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_ref: database_server + folder_name: database + file_name: database.db + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_ref: web_server + service_ref: web_server_database_client + agent_settings: # ... diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index a5a5fc77..ba1e8e66 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -5,41 +5,13 @@ from gymnasium import spaces from pydantic import BaseModel from primaite.simulator.sim_container import Simulation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE from primaite import getLogger _LOGGER = getLogger(__name__) if TYPE_CHECKING: from primaite.game.session import PrimaiteSession -NOT_PRESENT_IN_STATE = object() -""" -Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes -the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose. -""" - - -def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: - """ - Access an item from a deeply dictionary with a list of keys. - - For example, if the dictionary is {1: 'a', 2: {3: {4: 'b'}}}, then the key [2, 3, 4] would return 'b', and the key - [2, 3] would return {4: 'b'}. Raises a KeyError if specified key does not exist at any level of nesting. - - :param dictionary: Deeply nested dictionary - :type dictionary: Dict - :param keys: List of dict keys used to traverse the nested dict. Each item corresponds to one level of depth. - :type keys: List[Hashable] - :return: The value in the dictionary - :rtype: Any - """ - key_list = [*keys] # copy keys to a new list to prevent editing original list - if len(key_list) == 0: - return dictionary - k = key_list.pop(0) - if k not in dictionary: - return NOT_PRESENT_IN_STATE - return access_from_nested_dict(dictionary[k], key_list) - class AbstractObservation(ABC): @abstractmethod diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 18925edc..b7a4bb24 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -1,34 +1,131 @@ +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, TYPE_CHECKING +from primaite import getLogger +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession class AbstractReward: - def __init__(self): - ... @abstractmethod def calculate(self, state: Dict) -> float: - return 0.3 + return 0.0 + + @abstractmethod + @classmethod + def from_config(cls, config:dict) -> "AbstractReward": + return cls() class DummyReward(AbstractReward): def calculate(self, state: Dict) -> float: - return -0.1 + return 0.0 + + @classmethod + def from_config(cls, config: dict) -> "DummyReward": + return cls() + +class DatabaseFileIntegrity(AbstractReward): + def __init__(self, node_uuid:str, folder_name:str, file_name:str) -> None: + self.location_in_state = ["network", "node", node_uuid, "file_system", ""] + + def calculate(self, state: Dict) -> float: + database_file_state = access_from_nested_dict(state, self.location_in_state) + health_status = database_file_state['health_status'] + if health_status == "corrupted": + return -1 + elif health_status == "good": + return 1 + else: + return 0 + + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity": + node_ref = config.get("node_ref") + folder_name = config.get("folder_name") + file_name = config.get("file_name") + if not (node_ref): + _LOGGER.error(f"{cls.__name__} could not be initialised from config because node_ref parameter was not specified") + return DummyReward() #TODO: better error handling + if not folder_name: + _LOGGER.error(f"{cls.__name__} could not be initialised from config because folder_name parameter was not specified") + return DummyReward() # TODO: better error handling + if not file_name: + _LOGGER.error(f"{cls.__name__} could not be initialised from config because file_name parameter was not specified") + return DummyReward() # TODO: better error handling + node_uuid = session.ref_map_nodes[node_ref].uuid + if not node_uuid: + _LOGGER.error(f"{cls.__name__} could not be initialised from config because the referenced node could not be found in the simulation") + return DummyReward() # TODO: better error handling + + return cls(node_uuid = node_uuid, folder_name=folder_name, file_name=file_name) + +class WebServer404Penalty(AbstractReward): + def __init__(self, node_uuid:str, service_uuid:str) -> None: + self.location_in_state = ['network','node', node_uuid, 'services', service_uuid] + + def calculate(self, state: Dict) -> float: + web_service_state = access_from_nested_dict(state, self.location_in_state) + most_recent_return_code = web_service_state['most_recent_return_code'] + if most_recent_return_code == 200: + return 1 + elif most_recent_return_code == 404: + return -1 + else: + return 0 + + @classmethod + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty": + node_ref = config.get("node_ref") + service_ref = config.get("service_ref") + if not (node_ref and service_ref): + msg = f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not found in reward config." + _LOGGER.warn(msg) + return DummyReward() #TODO: should we error out with incorrect inputs? Probably! + node_uuid = session.ref_map_nodes[node_ref].uuid + service_uuid = session.ref_map_services[service_ref].uuid + if not (node_uuid and service_uuid): + msg = f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not found in the simulator." + _LOGGER.warn(msg) + return DummyReward() # TODO: consider erroring here as well + + return cls(node_uuid=node_uuid, service_uuid=service_uuid) class RewardFunction: - __rew_class_identifiers: Dict[str, type[AbstractReward]] = {"DUMMY": DummyReward} + __rew_class_identifiers: Dict[str, type[AbstractReward]] = { + "DUMMY": DummyReward, + "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, + "WEB_SERVER_404_PENALTY": WebServer404Penalty, + } - def __init__(self, reward_function: AbstractReward): - self.reward: AbstractReward = reward_function + def __init__(self): + self.reward_components: List[Tuple[AbstractReward, float]] = [] + "attribute reward_components keeps track of reward components and the weights assigned to each." + + def regsiter_component(self, component:AbstractReward, weight:float=1.0) -> None: + self.reward_components.append((component, weight)) def calculate(self, state: Dict) -> float: - return self.reward.calculate(state) + total = 0.0 + for comp_and_weight in self.reward_components: + comp = comp_and_weight[0] + weight = comp_and_weight[1] + total += weight * comp.calculate(state=state) + return total @classmethod - def from_config(cls, cfg: Dict) -> "RewardFunction": - for rew_component_cfg in cfg["reward_components"]: + def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction": + new = cls() + + for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] - rew_component = cls.__rew_class_identifiers[rew_type]() - new = cls(reward_function=rew_component) + weight = rew_component_cfg["weight"] + rew_class = cls.__rew_class_identifiers[rew_type] + rew_instance = rew_class.from_config(config=rew_component_cfg.get('options',{}), session=session) + new.regsiter_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/agent/utils.py b/src/primaite/game/agent/utils.py new file mode 100644 index 00000000..ad6dbefe --- /dev/null +++ b/src/primaite/game/agent/utils.py @@ -0,0 +1,29 @@ +from typing import Dict, Sequence, Hashable, Any + +NOT_PRESENT_IN_STATE = object() +""" +Need an object to return when the sim state does not contain a requested value. Cannot use None because sometimes +the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose. +""" + +def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: + """ + Access an item from a deeply dictionary with a list of keys. + + For example, if the dictionary is {1: 'a', 2: {3: {4: 'b'}}}, then the key [2, 3, 4] would return 'b', and the key + [2, 3] would return {4: 'b'}. Raises a KeyError if specified key does not exist at any level of nesting. + + :param dictionary: Deeply nested dictionary + :type dictionary: Dict + :param keys: List of dict keys used to traverse the nested dict. Each item corresponds to one level of depth. + :type keys: List[Hashable] + :return: The value in the dictionary + :rtype: Any + """ + key_list = [*keys] # copy keys to a new list to prevent editing original list + if len(key_list) == 0: + return dictionary + k = key_list.pop(0) + if k not in dictionary: + return NOT_PRESENT_IN_STATE + return access_from_nested_dict(dictionary[k], key_list) \ No newline at end of file