From fe6a8e6e97cb7a3837091950d7038b52fe1ffbff Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 17 Oct 2024 13:24:57 +0100 Subject: [PATCH 01/24] #2913: Initial commit of new AbstractReward class. --- src/primaite/game/agent/rewards.py | 61 +++++++++++++++++++----------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1de34b40..0db3cc28 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -27,9 +27,10 @@ the structure: service_ref: web_server_database_client ``` """ -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from pydantic import BaseModel from typing_extensions import Never from primaite import getLogger @@ -42,9 +43,35 @@ _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] -class AbstractReward: +class AbstractReward(BaseModel): """Base class for reward function components.""" + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" + + type: str + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate node adder {identifier}") + cls._registry[identifier] = cls + + @classmethod + @abstractmethod + def from_config(cls, config: Dict) -> "AbstractReward": + """Create a reward function component from a config dictionary. + + :param config: dict of options for the reward component's constructor + :type config: dict + :return: The reward component. + :rtype: AbstractReward + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + + return cls + @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -58,21 +85,9 @@ class AbstractReward: """ return 0.0 - @classmethod - @abstractmethod - def from_config(cls, config: dict) -> "AbstractReward": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: dict - :return: The reward component. - :rtype: AbstractReward - """ - return cls() - - -class DummyReward(AbstractReward): - """Dummy reward function component which always returns 0.""" +class DummyReward(AbstractReward, identifier="DummyReward"): + """Dummy reward function component which always returns 0.0""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -98,7 +113,7 @@ class DummyReward(AbstractReward): return cls() -class DatabaseFileIntegrity(AbstractReward): +class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: @@ -168,7 +183,7 @@ class DatabaseFileIntegrity(AbstractReward): return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) -class WebServer404Penalty(AbstractReward): +class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: @@ -241,7 +256,7 @@ class WebServer404Penalty(AbstractReward): return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) -class WebpageUnavailablePenalty(AbstractReward): +class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -325,7 +340,7 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname, sticky=sticky) -class GreenAdminDatabaseUnreachablePenalty(AbstractReward): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -393,7 +408,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return cls(node_hostname=node_hostname, sticky=sticky) -class SharedReward(AbstractReward): +class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" def __init__(self, agent_name: Optional[str] = None) -> None: @@ -446,7 +461,7 @@ class SharedReward(AbstractReward): return cls(agent_name=agent_name) -class ActionPenalty(AbstractReward): +class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: From 419a86114d0b6eca4d2da343c427ddddf1b297cf Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 17 Oct 2024 16:35:13 +0100 Subject: [PATCH 02/24] #2913: Now with ConfigSchemas. --- src/primaite/game/agent/rewards.py | 39 +++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 0db3cc28..4198af27 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -87,7 +87,7 @@ class AbstractReward(BaseModel): class DummyReward(AbstractReward, identifier="DummyReward"): - """Dummy reward function component which always returns 0.0""" + """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -116,6 +116,13 @@ class DummyReward(AbstractReward, identifier="DummyReward"): class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + node_hostname: str + folder_name: str + file_name: str + def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: """Initialise the reward component. @@ -186,6 +193,13 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebServer404Penalty.""" + + node_hostname: str + service_name: str + sticky: bool = True + def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: """Initialise the reward component. @@ -259,6 +273,12 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebpageUnavailablePenalty.""" + + node_hostname: str + sticky: bool = True + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. @@ -343,6 +363,12 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + + node_hostname: str + sticky: bool = True + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. @@ -411,6 +437,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for SharedReward.""" + + agent_name: str + def __init__(self, agent_name: Optional[str] = None) -> None: """ Initialise the shared reward. @@ -464,6 +495,12 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for ActionPenalty.""" + + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 + def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: """ Initialise the reward. From bbcbb26f5edd79202bfed16b46a7bc3573b60397 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 21 Oct 2024 14:43:51 +0100 Subject: [PATCH 03/24] #2913: Fix ActionPenalty. --- src/primaite/game/agent/rewards.py | 35 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 4198af27..1158d919 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -28,7 +28,7 @@ the structure: ``` """ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union from pydantic import BaseModel from typing_extensions import Never @@ -51,6 +51,8 @@ class AbstractReward(BaseModel): type: str + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if identifier in cls._registry: @@ -58,7 +60,6 @@ class AbstractReward(BaseModel): cls._registry[identifier] = cls @classmethod - @abstractmethod def from_config(cls, config: Dict) -> "AbstractReward": """Create a reward function component from a config dictionary. @@ -69,7 +70,8 @@ class AbstractReward(BaseModel): """ if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") - + adder_class = cls._registry[config["type"]] + adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) return cls @abstractmethod @@ -276,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" - node_hostname: str + node_hostname: str = "" sticky: bool = True def __init__(self, node_hostname: str, sticky: bool = True) -> None: @@ -494,6 +496,8 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" @@ -501,19 +505,20 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 - def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - """ - Initialise the reward. + # def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: + # """ + # Initialise the reward. - Reward or penalise agents for doing nothing or taking actions. + # Reward or penalise agents for doing nothing or taking actions. - :param action_penalty: Reward to give agents for taking any action except DONOTHING - :type action_penalty: float - :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action - :type do_nothing_penalty: float - """ - self.action_penalty = action_penalty - self.do_nothing_penalty = do_nothing_penalty + # :param action_penalty: Reward to give agents for taking any action except DONOTHING + # :type action_penalty: float + # :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action + # :type do_nothing_penalty: float + # """ + # super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + # self.action_penalty = action_penalty + # self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. From 0cf8e20e6da5d8941dec6080ca687649437106a0 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 21 Oct 2024 17:11:11 +0100 Subject: [PATCH 04/24] #2913: Update reward classes to work with pydantic. --- src/primaite/game/agent/rewards.py | 110 ++++++++++-------- .../game_layer/test_rewards.py | 2 +- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1158d919..9777441b 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -274,28 +274,34 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" + node_hostname: str = "" + sticky: bool = True + reward: float = 0.0 + location_in_state: List[str] = [""] + _node: str = node_hostname class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" node_hostname: str = "" sticky: bool = True + reward: float = 0.0 - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + # def __init__(self, node_hostname: str, sticky: bool = True) -> None: + # """ + # Initialise the reward component. - :param node_hostname: Hostname of the node which has the web browser. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + # :param node_hostname: Hostname of the node which has the web browser. + # :type node_hostname: str + # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + # the reward if there were any responses this timestep. + # :type sticky: bool + # """ + # self._node: str = node_hostname + # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + # self.sticky: bool = sticky + # self.reward: float = 0.0 + # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -311,6 +317,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe :return: Reward value :rtype: float """ + self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -364,6 +371,10 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" + node_hostname: str = "" + _node: str = node_hostname + sticky: bool = True + reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" @@ -371,21 +382,21 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi node_hostname: str sticky: bool = True - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + # def __init__(self, node_hostname: str, sticky: bool = True) -> None: + # """ + # Initialise the reward component. - :param node_hostname: Hostname of the node where the database client sits. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + # :param node_hostname: Hostname of the node where the database client sits. + # :type node_hostname: str + # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + # the reward if there were any responses this timestep. + # :type sticky: bool + # """ + # self._node: str = node_hostname + # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + # self.sticky: bool = sticky + # self.reward: float = 0.0 + # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -438,37 +449,38 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" + agent_name: str class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" agent_name: str - def __init__(self, agent_name: Optional[str] = None) -> None: + # def __init__(self, agent_name: Optional[str] = None) -> None: + # """ + # Initialise the shared reward. + + # The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work + # correctly. + + # :param agent_name: The name whose reward is an input + # :type agent_name: Optional[str] + # """ + # # self.agent_name = agent_name + # """Agent whose reward to track.""" + + def default_callback(agent_name: str) -> Never: """ - Initialise the shared reward. + Default callback to prevent calling this reward until it's properly initialised. - The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - correctly. - - :param agent_name: The name whose reward is an input - :type agent_name: Optional[str] + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. """ - self.agent_name = agent_name - """Agent whose reward to track.""" + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - def default_callback(agent_name: str) -> Never: - """ - Default callback to prevent calling this reward until it's properly initialised. - - SharedReward should not be used until the game layer replaces self.callback with a reference to the - function that retrieves the desired agent's reward. Therefore, we define this default callback that raises - an error. - """ - raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - - self.callback: Callable[[str], float] = default_callback - """Method that retrieves an agent's current reward given the agent's name.""" + callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it. diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 0005b508..bf707feb 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -74,7 +74,7 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty("client_1") + comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1") request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) From f95ba8cbbcd92779a66684c6bd2dbbf51052e5d2 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 11:01:35 +0100 Subject: [PATCH 05/24] #2913: Fix remaining pydantic errors. --- src/primaite/game/agent/rewards.py | 141 ++++++------------ .../_game/_agent/test_sticky_rewards.py | 12 +- 2 files changed, 48 insertions(+), 105 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 9777441b..1f870e83 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -28,7 +28,7 @@ the structure: ``` """ from abc import ABC, abstractmethod -from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union from pydantic import BaseModel from typing_extensions import Never @@ -118,6 +118,12 @@ class DummyReward(AbstractReward, identifier="DummyReward"): class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + node_hostname: str + folder_name: str + file_name: str + location_in_state: List[str] = [""] + reward: float = 0.0 + class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" @@ -125,27 +131,6 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): folder_name: str file_name: str - def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: - """Initialise the reward component. - - :param node_hostname: Hostname of the node which contains the database file. - :type node_hostname: str - :param folder_name: folder which contains the database file. - :type folder_name: str - :param file_name: name of the database file. - :type file_name: str - """ - self.location_in_state = [ - "network", - "nodes", - node_hostname, - "file_system", - "folders", - folder_name, - "files", - file_name, - ] - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -156,6 +141,17 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "file_system", + "folders", + self.folder_name, + "files", + self.file_name, + ] + database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: _LOGGER.debug( @@ -195,6 +191,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" + node_hostname: str + service_name: str + sticky: bool = True + location_in_state: List[str] = [""] + reward: float = 0.0 + class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" @@ -202,22 +204,6 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): service_name: str sticky: bool = True - def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: - """Initialise the reward component. - - :param node_hostname: Hostname of the node which contains the web server service. - :type node_hostname: str - :param service_name: Name of the web server service. - :type service_name: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -228,6 +214,13 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "services", + self.service_name, + ] web_service_state = access_from_nested_dict(state, self.location_in_state) # if webserver is no longer installed on the node, return 0 @@ -274,6 +267,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" + node_hostname: str = "" sticky: bool = True reward: float = 0.0 @@ -287,22 +281,6 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe sticky: bool = True reward: float = 0.0 - # def __init__(self, node_hostname: str, sticky: bool = True) -> None: - # """ - # Initialise the reward component. - - # :param node_hostname: Hostname of the node which has the web browser. - # :type node_hostname: str - # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - # the reward if there were any responses this timestep. - # :type sticky: bool - # """ - # self._node: str = node_hostname - # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - # self.sticky: bool = sticky - # self.reward: float = 0.0 - # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -317,7 +295,13 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe :return: Reward value :rtype: float """ - self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"] + self.location_in_state = [ + "network", + "nodes", + self.node_hostname, + "applications", + "WebBrowser", + ] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -371,6 +355,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" + node_hostname: str = "" _node: str = node_hostname sticky: bool = True @@ -382,22 +367,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi node_hostname: str sticky: bool = True - # def __init__(self, node_hostname: str, sticky: bool = True) -> None: - # """ - # Initialise the reward component. - - # :param node_hostname: Hostname of the node where the database client sits. - # :type node_hostname: str - # :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - # the reward if there were any responses this timestep. - # :type sticky: bool - # """ - # self._node: str = node_hostname - # self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - # self.sticky: bool = sticky - # self.reward: float = 0.0 - # """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -449,6 +418,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" + agent_name: str class ConfigSchema(AbstractReward.ConfigSchema): @@ -456,19 +426,6 @@ class SharedReward(AbstractReward, identifier="SharedReward"): agent_name: str - # def __init__(self, agent_name: Optional[str] = None) -> None: - # """ - # Initialise the shared reward. - - # The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - # correctly. - - # :param agent_name: The name whose reward is an input - # :type agent_name: Optional[str] - # """ - # # self.agent_name = agent_name - # """Agent whose reward to track.""" - def default_callback(agent_name: str) -> Never: """ Default callback to prevent calling this reward until it's properly initialised. @@ -508,6 +465,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" + action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -517,21 +475,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 - # def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - # """ - # Initialise the reward. - - # Reward or penalise agents for doing nothing or taking actions. - - # :param action_penalty: Reward to give agents for taking any action except DONOTHING - # :type action_penalty: float - # :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action - # :type do_nothing_penalty: float - # """ - # super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) - # self.action_penalty = action_penalty - # self.do_nothing_penalty = do_nothing_penalty - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 58f0fcc1..2ad1a322 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,7 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=False) + reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=False) # no response codes yet, reward is 0 codes = [] @@ -38,7 +38,7 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=True) + reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=True) # no response codes yet, reward is 0 codes = [] @@ -67,7 +67,7 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=False) + reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=False) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -127,7 +127,7 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=True) + reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=True) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -188,7 +188,7 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=False) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -244,7 +244,7 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=True) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] From 318f8926f0f7f0c85f3ce07bf551ca5ada88bbf9 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 12:14:30 +0100 Subject: [PATCH 06/24] #2913: Fix remaining test errors. --- src/primaite/game/agent/rewards.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 1f870e83..2386bed5 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -272,7 +272,6 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe sticky: bool = True reward: float = 0.0 location_in_state: List[str] = [""] - _node: str = node_hostname class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" @@ -311,7 +310,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.node_hostname, "application", "WebBrowser", "execute", @@ -326,7 +325,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self._node} was not reported in the simulation state. Returning 0.0", + f"{self.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: @@ -357,7 +356,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi """Penalises the agent when the green db clients fail to connect to the database.""" node_hostname: str = "" - _node: str = node_hostname sticky: bool = True reward: float = 0.0 @@ -385,7 +383,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.node_hostname, "application", "DatabaseClient", "execute", From 37bdbaf0d1ebb624b0b279d6a2ba8ba7cfd142f7 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 16:15:04 +0100 Subject: [PATCH 07/24] #2913: Fix JSON breakage and old-style PORTS and PROTOCOL usage. --- .../Command-&-Control-E2E-Demonstration.ipynb | 10 +++--- .../create-simulation_demo.ipynb | 31 ++++++++++--------- .../network_simulator_demo.ipynb | 14 ++++----- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index 6e6819fa..368dccf8 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1781,9 +1781,11 @@ "outputs": [], "source": [ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] @@ -1804,7 +1806,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -1818,7 +1820,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index f573f251..7ce8baaf 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -159,16 +159,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n", "from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.file_system.file_system import FileSystem\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "\n", "# no applications exist yet so we will create our own.\n", "class MSPaint(Application, identifier=\"MSPaint\"):\n", @@ -178,16 +179,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -203,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -212,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -249,7 +250,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 2d5b4772..b09baa85 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "3", "metadata": { "tags": [] @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "4", "metadata": { "tags": [] @@ -532,12 +532,12 @@ }, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol["ICMP"],\n", + " protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" @@ -650,7 +650,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -664,7 +664,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, From c3f266e40116514cae6d842351524b42927e43d9 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 16:26:57 +0100 Subject: [PATCH 08/24] #2913: Remove unneeded import and pre-commit changes. --- .../Command-&-Control-E2E-Demonstration.ipynb | 1 - .../create-simulation_demo.ipynb | 22 +++++++++---------- .../network_simulator_demo.ipynb | 4 ++-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index 368dccf8..d2972fa9 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1780,7 +1780,6 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", "from primaite.utils.validation.port import PORT_LOOKUP\n", "\n", diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index 7ce8baaf..117ea019 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -213,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index b09baa85..4b620eee 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "3", "metadata": { "tags": [] @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "4", "metadata": { "tags": [] From 85216bec942a472f9007dba41e1be2a2c0846456 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 22 Oct 2024 16:48:30 +0100 Subject: [PATCH 09/24] #2913: Rename notebook to replace '&'. --- ...stration.ipynb => Command-and-Control-E2E-Demonstration.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/primaite/notebooks/{Command-&-Control-E2E-Demonstration.ipynb => Command-and-Control-E2E-Demonstration.ipynb} (100%) diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb similarity index 100% rename from src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb rename to src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb From 6f6e4131b4cd39a108a93415cfdba20b1895b34a Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 29 Oct 2024 16:54:19 +0000 Subject: [PATCH 10/24] #2913: Handle case where server_ip_address is None --- src/primaite/simulator/system/applications/database_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index cd4b2a03..e030b306 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -308,6 +308,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + if self.server_ip_address is None: + return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None From 3c1bb2d546a051eca445b88d7e700a3dea3f1861 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 29 Oct 2024 16:57:11 +0000 Subject: [PATCH 11/24] #2913: Integration test fixes. --- src/primaite/simulator/system/core/session_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 75322e86..59390d68 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: - from primaite.simulator.network.hardware.base import NetworkInterface + from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog From 9fd862763b4dfac4f4015871d705cddd4d072eb5 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 30 Oct 2024 11:11:07 +0000 Subject: [PATCH 12/24] #2913: Ensure optional software in config file is enabled. --- src/primaite/game/game.py | 2 +- src/primaite/simulator/system/applications/database_client.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c8fbac4e..7c2f49e7 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -376,7 +376,7 @@ class PrimaiteGame: if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_class) + new_node.software_manager.install(service_class, **service_cfg.get('options', {})) new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index e030b306..2b2be7b2 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -309,6 +309,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"): if not self._can_perform_action(): return None if self.server_ip_address is None: + self.sys_log.warning( + f"{self.name}: Database server IP address not provided." + ) return None connection_request_id = str(uuid4()) From 97094aba795bfdeb5141f6d0080636635c74a377 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 30 Oct 2024 11:15:39 +0000 Subject: [PATCH 13/24] #2913: Pre-commit changes. --- src/primaite/game/game.py | 2 +- src/primaite/simulator/system/applications/database_client.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 7c2f49e7..7c5c93bc 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -376,7 +376,7 @@ class PrimaiteGame: if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_class, **service_cfg.get('options', {})) + new_node.software_manager.install(service_class, **service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 2b2be7b2..2079194a 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -309,9 +309,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"): if not self._can_perform_action(): return None if self.server_ip_address is None: - self.sys_log.warning( - f"{self.name}: Database server IP address not provided." - ) + self.sys_log.warning(f"{self.name}: Database server IP address not provided.") return None connection_request_id = str(uuid4()) From 77219db0411f50ddf422dbb9eba8cfe0c5aee0e5 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 30 Oct 2024 16:32:49 +0000 Subject: [PATCH 14/24] #2913: Remove dns_server option from config files. --- tests/assets/configs/basic_switched_network.yaml | 2 -- tests/assets/configs/fix_duration_one_item.yaml | 6 ++---- tests/assets/configs/software_fix_duration.yaml | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index fed0f52d..799bb571 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -202,8 +202,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml index bd0fb61f..4163bcfd 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -200,8 +200,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: @@ -232,8 +230,8 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 + # options: + # dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index 1a28258b..2b72e85f 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -209,7 +209,7 @@ simulation: services: - type: DNSClient options: - dns_server: 192.168.1.10 + # dns_server: 192.168.1.10 fix_duration: 3 - type: DNSServer options: @@ -250,8 +250,6 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 From 7d977c809538ce236d38b21aec1c97787abfb29a Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 30 Oct 2024 16:33:14 +0000 Subject: [PATCH 15/24] #2913: Fix config path for test. --- .../integration_tests/extensions/test_extendable_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5addcbd7..71a60194 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests import TEST_ASSETS_ROOT from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch @@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.nodes.super_computer import SuperComputer from tests.integration_tests.extensions.services.extended_service import ExtendedService +CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml" + def test_extended_example_config(): """Test that the example config can be parsed properly.""" - config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") - game = load_config(config_path) + game = load_config(CONFIG_PATH) network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network From eb827f7e0a98b8c4dc9a94dc0e739a7c213d1d9f Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 31 Oct 2024 14:42:26 +0000 Subject: [PATCH 16/24] #2913: How-To guide initial commit. --- .../how_to_guides/extensible_rewards.rst | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 docs/source/how_to_guides/extensible_rewards.rst diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst new file mode 100644 index 00000000..3505d66c --- /dev/null +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -0,0 +1,72 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +.. _about: + +Extensible Rewards +****************** + +Changes to reward class structure. +================================== + +Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). +Within the reward class is a ConfigSchema class responsible for ensuring config file data is in the +correct format. The `.from_config()` method is generally unchanged. + +Inheriting from `BaseModel` removes the need for an `__init__` method bu means that object +attributes need to be passed by keyword. + +.. code:: Python + +class AbstractReward(BaseModel): + """Base class for reward function components.""" + + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" + + type: str + + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate node adder {identifier}") + cls._registry[identifier] = cls + + @classmethod + def from_config(cls, config: Dict) -> "AbstractReward": + """Create a reward function component from a config dictionary. + + :param config: dict of options for the reward component's constructor + :type config: dict + :return: The reward component. + :rtype: AbstractReward + """ + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + adder_class = cls._registry[config["type"]] + adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) + return cls + + @abstractmethod + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ + return 0.0 + + +Changes to YAML file. +===================== +.. code:: YAML + + There's no longer a need to provide a `dns_server` as an option in the simulation section + of the config file. \ No newline at end of file From 6b29362bf955ece37b39b31d34ee189231f9b5e5 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 31 Oct 2024 14:42:50 +0000 Subject: [PATCH 17/24] #2913: Tidy up config files. --- tests/assets/configs/fix_duration_one_item.yaml | 2 -- tests/assets/configs/software_fix_duration.yaml | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml index 4163bcfd..c74590a1 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -230,8 +230,6 @@ simulation: server_password: arcd services: - type: DNSClient - # options: - # dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index 2b72e85f..6a705b37 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -209,7 +209,6 @@ simulation: services: - type: DNSClient options: - # dns_server: 192.168.1.10 fix_duration: 3 - type: DNSServer options: From b849ea6312c21366256fc297297017112dab1d9b Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 4 Nov 2024 17:41:43 +0000 Subject: [PATCH 18/24] #2913: Remove from_config() and refactor (WIP). --- src/primaite/game/agent/rewards.py | 167 +++++------------- .../game_layer/test_rewards.py | 11 +- 2 files changed, 48 insertions(+), 130 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 2386bed5..03764e4b 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -46,6 +46,14 @@ WhereType = Optional[Iterable[Union[str, int]]] class AbstractReward(BaseModel): """Base class for reward function components.""" + config: "AbstractReward.ConfigSchema" + + # def __init__(self, schema_name, **kwargs): + # super.__init__(self, **kwargs) + # # Create ConfigSchema class + # self.config_class = type(schema_name, (BaseModel, ABC), **kwargs) + # self.config = self.config_class() + class ConfigSchema(BaseModel, ABC): """Config schema for AbstractReward.""" @@ -56,7 +64,7 @@ class AbstractReward(BaseModel): def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if identifier in cls._registry: - raise ValueError(f"Duplicate node adder {identifier}") + raise ValueError(f"Duplicate reward {identifier}") cls._registry[identifier] = cls @classmethod @@ -70,9 +78,10 @@ class AbstractReward(BaseModel): """ if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") - adder_class = cls._registry[config["type"]] - adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) - return cls + reward_class = cls._registry[config["type"]] + reward_config = reward_class.ConfigSchema(**config) + reward_class(config=reward_config) + return reward_class @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -103,30 +112,18 @@ class DummyReward(AbstractReward, identifier="DummyReward"): """ return 0.0 - @classmethod - def from_config(cls, config: dict) -> "DummyReward": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor. Should be empty. - :type config: dict - :return: The reward component. - :rtype: DummyReward - """ - return cls() - class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - node_hostname: str - folder_name: str - file_name: str + config: "DatabaseFileIntegrity.ConfigSchema" location_in_state: List[str] = [""] reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" + type: str = "DatabaseFileIntegrity" node_hostname: str folder_name: str file_name: str @@ -144,12 +141,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "file_system", "folders", - self.folder_name, + self.config.folder_name, "files", - self.file_name, + self.config.file_name, ] database_file_state = access_from_nested_dict(state, self.location_in_state) @@ -168,38 +165,18 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): else: return 0 - @classmethod - def from_config(cls, config: Dict) -> "DatabaseFileIntegrity": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: DatabaseFileIntegrity - """ - node_hostname = config.get("node_hostname") - folder_name = config.get("folder_name") - file_name = config.get("file_name") - if not (node_hostname and folder_name and file_name): - msg = f"{cls.__name__} could not be initialised with parameters {config}" - _LOGGER.error(msg) - raise ValueError(msg) - - return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) - class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): """Reward function component which penalises the agent when the web server returns a 404 error.""" - node_hostname: str - service_name: str - sticky: bool = True + config: "WebServer404Penalty.ConfigSchema" location_in_state: List[str] = [""] reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" + type: str = "WebServer404Penalty" node_hostname: str service_name: str sticky: bool = True @@ -217,9 +194,9 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "services", - self.service_name, + self.config.service_name, ] web_service_state = access_from_nested_dict(state, self.location_in_state) @@ -242,43 +219,20 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): return self.reward - @classmethod - def from_config(cls, config: Dict) -> "WebServer404Penalty": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: WebServer404Penalty - """ - node_hostname = config.get("node_hostname") - service_name = config.get("service_name") - if not (node_hostname and service_name): - msg = ( - f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " - "found in reward config." - ) - _LOGGER.warning(msg) - raise ValueError(msg) - sticky = config.get("sticky", True) - - return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) - class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): """Penalises the agent when the web browser fails to fetch a webpage.""" - node_hostname: str = "" - sticky: bool = True - reward: float = 0.0 - location_in_state: List[str] = [""] + config: "WebpageUnavailablePenalty.ConfigSchema" + reward: float = 0.0 # XXX: Private attribute? + location_in_state: List[str] = [""] # Calculate in __init__()? class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" + type: str = "WebpageUnavailablePenalty" node_hostname: str = "" sticky: bool = True - reward: float = 0.0 def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -297,7 +251,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe self.location_in_state = [ "network", "nodes", - self.node_hostname, + self.config.node_hostname, "applications", "WebBrowser", ] @@ -310,14 +264,14 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe request_attempted = last_action_response.request == [ "network", "node", - self.node_hostname, + self.config.node_hostname, "application", "WebBrowser", "execute", ] # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: + if not request_attempted and self.config.sticky: return self.reward if last_action_response.response.status != "success": @@ -339,29 +293,17 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe return self.reward - @classmethod - def from_config(cls, config: dict) -> AbstractReward: - """ - Build the reward component object from config. - - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): """Penalises the agent when the green db clients fail to connect to the database.""" - node_hostname: str = "" - sticky: bool = True + config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" reward: float = 0.0 class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + type: str = "GreenAdminDatabaseUnreachablePenalty" node_hostname: str sticky: bool = True @@ -383,7 +325,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi request_attempted = last_action_response.request == [ "network", "node", - self.node_hostname, + self.config.node_hostname, "application", "DatabaseClient", "execute", @@ -392,7 +334,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi if request_attempted: # if agent makes request, always recalculate fresh value last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 - elif not self.sticky: # if no new request and not sticky, set reward to 0 + elif not self.config.sticky: # if no new request and not sticky, set reward to 0 last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step @@ -401,27 +343,16 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi return self.reward - @classmethod - def from_config(cls, config: Dict) -> AbstractReward: - """ - Build the reward component object from config. - - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - class SharedReward(AbstractReward, identifier="SharedReward"): """Adds another agent's reward to the overall reward.""" - agent_name: str + config: "SharedReward.ConfigSchema" class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" + type: str = "SharedReward" agent_name: str def default_callback(agent_name: str) -> Never: @@ -447,29 +378,18 @@ class SharedReward(AbstractReward, identifier="SharedReward"): :return: Reward value :rtype: float """ - return self.callback(self.agent_name) - - @classmethod - def from_config(cls, config: Dict) -> "SharedReward": - """ - Build the SharedReward object from config. - - :param config: Configuration dictionary - :type config: Dict - """ - agent_name = config.get("agent_name") - return cls(agent_name=agent_name) + return self.callback(self.config.agent_name) class ActionPenalty(AbstractReward, identifier="ActionPenalty"): """Apply a negative reward when taking any action except DONOTHING.""" - action_penalty: float = -1.0 - do_nothing_penalty: float = 0.0 + config: "ActionPenalty.ConfigSchema" class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" + type: str = "ActionPenalty" action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -484,16 +404,9 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): :rtype: float """ if last_action_response.action == "DONOTHING": - return self.do_nothing_penalty + return self.config.do_nothing_penalty else: - return self.action_penalty - - @classmethod - def from_config(cls, config: Dict) -> "ActionPenalty": - """Build the ActionPenalty object from config.""" - action_penalty = config.get("action_penalty", -1.0) - do_nothing_penalty = config.get("do_nothing_penalty", 0.0) - return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + return self.config.action_penalty class RewardFunction: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index bf707feb..d4236d1b 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -23,7 +23,9 @@ def test_WebpageUnavailablePenalty(game_and_agent): # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent - comp = WebpageUnavailablePenalty(node_hostname="client_1") + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = WebpageUnavailablePenalty(config=schema) + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() @@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1") + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = GreenAdminDatabaseUnreachablePenalty(config=schema) request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) @@ -147,7 +150,9 @@ def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + # Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + Penalty = ActionPenalty(schema) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate( From 370bcfc47682049cafe238f97e089b5d8b5ea013 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 6 Nov 2024 11:35:06 +0000 Subject: [PATCH 19/24] #2913: Make rewards work with config file. --- src/primaite/game/agent/rewards.py | 51 ++++++++----------- src/primaite/game/game.py | 2 +- .../game_layer/test_rewards.py | 2 +- 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 03764e4b..029597a0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -79,9 +79,8 @@ class AbstractReward(BaseModel): if config["type"] not in cls._registry: raise ValueError(f"Invalid reward type {config['type']}") reward_class = cls._registry[config["type"]] - reward_config = reward_class.ConfigSchema(**config) - reward_class(config=reward_config) - return reward_class + reward_obj = reward_class(config=reward_class.ConfigSchema(**config)) + return reward_obj @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -97,7 +96,7 @@ class AbstractReward(BaseModel): return 0.0 -class DummyReward(AbstractReward, identifier="DummyReward"): +class DummyReward(AbstractReward, identifier="DUMMY"): """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -113,7 +112,7 @@ class DummyReward(AbstractReward, identifier="DummyReward"): return 0.0 -class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" config: "DatabaseFileIntegrity.ConfigSchema" @@ -123,7 +122,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for DatabaseFileIntegrity.""" - type: str = "DatabaseFileIntegrity" + type: str = "DATABASE_FILE_INTEGRITY" node_hostname: str folder_name: str file_name: str @@ -166,7 +165,7 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"): return 0 -class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): +class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): """Reward function component which penalises the agent when the web server returns a 404 error.""" config: "WebServer404Penalty.ConfigSchema" @@ -176,7 +175,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebServer404Penalty.""" - type: str = "WebServer404Penalty" + type: str = "WEB_SERVER_404_PENALTY" node_hostname: str service_name: str sticky: bool = True @@ -220,7 +219,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"): return self.reward -class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"): +class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"): """Penalises the agent when the web browser fails to fetch a webpage.""" config: "WebpageUnavailablePenalty.ConfigSchema" @@ -230,7 +229,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for WebpageUnavailablePenalty.""" - type: str = "WebpageUnavailablePenalty" + type: str = "WEBPAGE_UNAVAILABLE_PENALTY" node_hostname: str = "" sticky: bool = True @@ -294,7 +293,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe return self.reward -class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): """Penalises the agent when the green db clients fail to connect to the database.""" config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" @@ -303,7 +302,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi class ConfigSchema(AbstractReward.ConfigSchema): """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" - type: str = "GreenAdminDatabaseUnreachablePenalty" + type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY" node_hostname: str sticky: bool = True @@ -344,7 +343,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi return self.reward -class SharedReward(AbstractReward, identifier="SharedReward"): +class SharedReward(AbstractReward, identifier="SHARED_REWARD"): """Adds another agent's reward to the overall reward.""" config: "SharedReward.ConfigSchema" @@ -352,7 +351,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for SharedReward.""" - type: str = "SharedReward" + type: str = "SHARED_REWARD" agent_name: str def default_callback(agent_name: str) -> Never: @@ -381,7 +380,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"): return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward, identifier="ActionPenalty"): +class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): """Apply a negative reward when taking any action except DONOTHING.""" config: "ActionPenalty.ConfigSchema" @@ -389,7 +388,7 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): class ConfigSchema(AbstractReward.ConfigSchema): """Config schema for ActionPenalty.""" - type: str = "ActionPenalty" + type: str = "ACTION_PENALTY" action_penalty: float = -1.0 do_nothing_penalty: float = 0.0 @@ -412,17 +411,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"): class RewardFunction: """Manages the reward function for the agent.""" - rew_class_identifiers: Dict[str, Type[AbstractReward]] = { - "DUMMY": DummyReward, - "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, - "WEB_SERVER_404_PENALTY": WebServer404Penalty, - "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, - "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, - "SHARED_REWARD": SharedReward, - "ACTION_PENALTY": ActionPenalty, - } - """List of reward class identifiers.""" - def __init__(self): """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] @@ -457,7 +445,7 @@ class RewardFunction: @classmethod def from_config(cls, config: Dict) -> "RewardFunction": - """Create a reward function from a config dictionary. + """Create a reward function from a config dictionary and its related reward class. :param config: dict of options for the reward manager's constructor :type config: Dict @@ -468,8 +456,11 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] + # XXX: If options key is missing add key then add type key. + if "options" not in rew_component_cfg: + rew_component_cfg["options"] = {} + rew_component_cfg["options"]["type"] = rew_type weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) + rew_instance = AbstractReward.from_config(rew_component_cfg["options"]) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 7c5c93bc..51d0306c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -629,7 +629,7 @@ class PrimaiteGame: for comp, weight in agent.reward_function.reward_components: if isinstance(comp, SharedReward): comp: SharedReward - graph[name].add(comp.agent_name) + graph[name].add(comp.config.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index d4236d1b..6544c82d 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -25,7 +25,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent: ControlledAgent schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) comp = WebpageUnavailablePenalty(config=schema) - + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() From 4c2ef6ea2a873d5d7bb965458db00bff21ef1dca Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 6 Nov 2024 14:52:22 +0000 Subject: [PATCH 20/24] #2913: Updated tests --- .../integration_tests/game_layer/test_rewards.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 6544c82d..742b2d35 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -18,7 +18,7 @@ from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent -def test_WebpageUnavailablePenalty(game_and_agent): +def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for failing to fetch a website.""" # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent @@ -55,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): assert agent.reward_function.current_reward == -0.7 -def test_uc2_rewards(game_and_agent): +def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" game, agent = game_and_agent agent: ControlledAgent @@ -142,8 +142,8 @@ def test_action_penalty_loads_from_config(): act_penalty_obj = comp[0] if act_penalty_obj is None: pytest.fail("Action penalty reward component was not added to the agent from config.") - assert act_penalty_obj.action_penalty == -0.75 - assert act_penalty_obj.do_nothing_penalty == 0.125 + assert act_penalty_obj.config.action_penalty == -0.75 + assert act_penalty_obj.config.do_nothing_penalty == 0.125 def test_action_penalty(): @@ -152,7 +152,7 @@ def test_action_penalty(): # Create an ActionPenalty Reward schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) # Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) - Penalty = ActionPenalty(schema) + Penalty = ActionPenalty(config=schema) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate( @@ -183,11 +183,12 @@ def test_action_penalty(): assert reward_value == 0.125 -def test_action_penalty_e2e(game_and_agent): +def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for doing actions to fetch a website.""" game, agent = game_and_agent agent: ControlledAgent - comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + comp = ActionPenalty(config=schema) agent.reward_function.register_component(comp, 1.0) From 9d6536fa6aee13858e39957d215061fecce5966f Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 6 Nov 2024 15:08:38 +0000 Subject: [PATCH 21/24] #2913: Pre-commit fix --- docs/source/how_to_guides/extensible_rewards.rst | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst index 3505d66c..2551eee0 100644 --- a/docs/source/how_to_guides/extensible_rewards.rst +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -11,10 +11,13 @@ Changes to reward class structure. ================================== Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). -Within the reward class is a ConfigSchema class responsible for ensuring config file data is in the -correct format. The `.from_config()` method is generally unchanged. +Within the reward class there is a ConfigSchema class responsible for ensuring config file data is +in the correct format. The `.from_config()` method is generally unchanged but should initialise the +attributes edfined in the ConfigSchema. +Each class requires an identifier string which is used by the ConfigSchema class to verify that it +hasn't previously been added to the registry. -Inheriting from `BaseModel` removes the need for an `__init__` method bu means that object +Inheriting from `BaseModel` removes the need for an `__init__` method but means that object attributes need to be passed by keyword. .. code:: Python @@ -69,4 +72,4 @@ Changes to YAML file. .. code:: YAML There's no longer a need to provide a `dns_server` as an option in the simulation section - of the config file. \ No newline at end of file + of the config file. From e0b885cc79497bf21a010d3c7313d0fbaba0728c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 7 Nov 2024 13:08:44 +0000 Subject: [PATCH 22/24] #2913: Changes to update test_sticky_rewards.py --- src/primaite/game/agent/rewards.py | 4 +-- .../_game/_agent/test_sticky_rewards.py | 33 +++++++++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 029597a0..05bca033 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -211,7 +211,7 @@ class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average - elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0 self.reward = 0.0 else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass @@ -278,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_ elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self.node_hostname} was not reported in the simulation state. Returning 0.0", + f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 2ad1a322..c758291f 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=False) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=False, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=True) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=True, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -67,7 +77,8 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=False) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -127,7 +138,8 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=True) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -188,7 +200,11 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=False) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=False, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -214,7 +230,6 @@ class TestGreenAdminDatabaseUnreachableSticky: # agent did nothing, because reward is not sticky, it goes back to 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] response = RequestResponse(status="success", data={}) - browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response @@ -244,7 +259,11 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=True) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=True, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] From 02d29f7fb9bb93ebe32c13503ee3e56dfe545369 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 7 Nov 2024 16:35:39 +0000 Subject: [PATCH 23/24] #2913: Updates to How-To guide --- .../how_to_guides/extensible_rewards.rst | 60 +++++++------------ 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst index 2551eee0..4dd24110 100644 --- a/docs/source/how_to_guides/extensible_rewards.rst +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -6,65 +6,47 @@ Extensible Rewards ****************** +Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward +types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE +core repository). Changes to reward class structure. ================================== Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). -Within the reward class there is a ConfigSchema class responsible for ensuring config file data is -in the correct format. The `.from_config()` method is generally unchanged but should initialise the -attributes edfined in the ConfigSchema. +Within the reward class there is a ConfigSchema class responsible for ensuring the config file data +is in the correct format. This also means there is little (if no) requirement for and `__init__` +method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`. Each class requires an identifier string which is used by the ConfigSchema class to verify that it hasn't previously been added to the registry. Inheriting from `BaseModel` removes the need for an `__init__` method but means that object attributes need to be passed by keyword. -.. code:: Python +To add a new reward class follow the example below. Note that the type attribute in the +`ConfigSchema` class should match the type used in the config file to define the reward. -class AbstractReward(BaseModel): - """Base class for reward function components.""" +.. code-block:: Python - class ConfigSchema(BaseModel, ABC): - """Config schema for AbstractReward.""" +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): + """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - type: str + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" - def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - if identifier in cls._registry: - raise ValueError(f"Duplicate node adder {identifier}") - cls._registry[identifier] = cls + type: str = "DATABASE_FILE_INTEGRITY" + node_hostname: str + folder_name: str + file_name: str - @classmethod - def from_config(cls, config: Dict) -> "AbstractReward": - """Create a reward function component from a config dictionary. - - :param config: dict of options for the reward component's constructor - :type config: dict - :return: The reward component. - :rtype: AbstractReward - """ - if config["type"] not in cls._registry: - raise ValueError(f"Invalid reward type {config['type']}") - adder_class = cls._registry[config["type"]] - adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config)) - return cls - - @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. + pass - :param state: Current simulation state - :type state: Dict - :param last_action_response: Current agent history state - :type last_action_response: AgentHistoryItem state - :return: Reward value - :rtype: float - """ - return 0.0 Changes to YAML file. From c9752f0dc5e94a7dd4f0b58004e03096648888d1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 3 Jan 2025 11:22:17 +0000 Subject: [PATCH 24/24] #2913 - minor comment cleanup --- docs/source/how_to_guides/extensible_rewards.rst | 2 +- src/primaite/game/agent/rewards.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst index 4dd24110..a01b9d8f 100644 --- a/docs/source/how_to_guides/extensible_rewards.rst +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -1,6 +1,6 @@ .. only:: comment - © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK .. _about: diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c5850d6e..a4c7c546 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -223,7 +223,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_ """Penalises the agent when the web browser fails to fetch a webpage.""" config: "WebpageUnavailablePenalty.ConfigSchema" - reward: float = 0.0 # XXX: Private attribute? + reward: float = 0.0 location_in_state: List[str] = [""] # Calculate in __init__()? class ConfigSchema(AbstractReward.ConfigSchema):