From 05f9751fa81dd2f9e822b1710864cbd2f42d2875 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 10:17:39 +0100 Subject: [PATCH 1/5] #2736 - implement instantaneous rewards --- src/primaite/game/agent/rewards.py | 137 ++++++++++++------ .../system/applications/database_client.py | 11 -- .../system/services/web_server/web_server.py | 20 ++- 3 files changed, 104 insertions(+), 64 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..00374791 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -151,14 +151,20 @@ class DatabaseFileIntegrity(AbstractReward): class WebServer404Penalty(AbstractReward): """Reward function component which penalises the agent when the web server returns a 404 error.""" - def __init__(self, node_hostname: str, service_name: str) -> None: + def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: """Initialise the reward component. :param node_hostname: Hostname of the node which contains the web server service. :type node_hostname: str :param service_name: Name of the web server service. :type service_name: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. + :type sticky: bool """ + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. 
Used for persisting sticky rewards.""" self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -168,16 +174,21 @@ class WebServer404Penalty(AbstractReward): :type state: Dict """ web_service_state = access_from_nested_dict(state, self.location_in_state) + + # if webserver is no longer installed on the node, return 0 if web_service_state is NOT_PRESENT_IN_STATE: return 0.0 - most_recent_return_code = web_service_state["last_response_status_code"] - # TODO: reward needs to use the current web state. Observation should return web state at the time of last scan. - if most_recent_return_code == 200: - return 1.0 - elif most_recent_return_code == 404: - return -1.0 - else: - return 0.0 + + codes = web_service_state.get("response_codes_this_timestep") + if codes or not self.sticky: # skip calculating if sticky and no new codes. Insted, reuse last step's value. + + def status2rew(status: int) -> int: + """Map status codes to reward values.""" + return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 + + self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average + + return self.reward @classmethod def from_config(cls, config: Dict) -> "WebServer404Penalty": @@ -197,23 +208,29 @@ class WebServer404Penalty(AbstractReward): ) _LOGGER.warning(msg) raise ValueError(msg) + sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, service_name=service_name) + return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) class WebpageUnavailablePenalty(AbstractReward): """Penalises the agent when the web browser fails to fetch a webpage.""" - def __init__(self, node_hostname: str) -> None: + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. :param node_hostname: Hostname of the node which has the web browser. 
:type node_hostname: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. + :type sticky: bool """ self._node: str = node_hostname self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self._last_request_failed: bool = False + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -223,31 +240,46 @@ class WebpageUnavailablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last webpage had a 200 status code, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]: - self._last_request_failed = last_action_response.response.status != "success" - - # if agent couldn't even get as far as sending the request (because for example the node was off), then - # apply a penalty - if self._last_request_failed: - return -1.0 - - # If the last request did actually go through, then check if the webpage also loaded web_browser_state = access_from_nested_dict(state, self.location_in_state) - if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: + + if web_browser_state is NOT_PRESENT_IN_STATE: + self.reward = 0.0 + + # check if the most recent action was to request the webpage + request_attempted = last_action_response.request == [ + "network", + "node", + self._node, + "application", + "WebBrowser", + "execute", + ] + + if ( + not request_attempted and self.sticky + ): # skip calculating if sticky and no new codes, reusing last step value + return self.reward + + if last_action_response.response.status != "success": + 
self.reward = -1.0 + # + elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", f"{self._node} was not reported in the simulation state. Returning 0.0", ) - return 0.0 # 0 if the web browser cannot be found - if not web_browser_state["history"]: - return 0.0 # 0 if no requests have been attempted yet + self.reward = 0.0 + elif not web_browser_state["history"]: + self.reward = 0.0 # 0 if no requests have been attempted yet outcome = web_browser_state["history"][-1]["outcome"] if outcome == "PENDING": - return 0.0 # 0 if a request was attempted but not yet resolved + self.reward = 0.0 # 0 if a request was attempted but not yet resolved elif outcome == 200: - return 1.0 # 1 for successful request + self.reward = 1.0 # 1 for successful request else: # includes failure codes and SERVER_UNREACHABLE - return -1.0 # -1 for failure + self.reward = -1.0 # -1 for failure + + return self.reward @classmethod def from_config(cls, config: dict) -> AbstractReward: @@ -258,22 +290,28 @@ class WebpageUnavailablePenalty(AbstractReward): :type config: Dict """ node_hostname = config.get("node_hostname") - return cls(node_hostname=node_hostname) + sticky = config.get("sticky", True) + return cls(node_hostname=node_hostname, sticky=sticky) class GreenAdminDatabaseUnreachablePenalty(AbstractReward): """Penalises the agent when the green db clients fail to connect to the database.""" - def __init__(self, node_hostname: str) -> None: + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. :param node_hostname: Hostname of the node where the database client sits. :type node_hostname: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. 
+ :type sticky: bool """ self._node: str = node_hostname self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self._last_request_failed: bool = False + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -284,25 +322,29 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request returned was able to connect to the database server, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: - self._last_request_failed = last_action_response.response.status != "success" - - # if agent couldn't even get as far as sending the request (because for example the node was off), then - # apply a penalty - if self._last_request_failed: - return -1.0 + db_state = access_from_nested_dict(state, self.location_in_state) # If the last request was actually sent, then check if the connection was established. 
- db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") - return 0.0 - last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + self.reward = 0.0 + + request_attempted = last_action_response.request == [ + "network", + "node", + self._node, + "application", + "DatabaseClient", + "execute", + ] + + if ( + not request_attempted and self.sticky + ): # skip calculating if sticky and no new codes, reusing last step value + return self.reward + + self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 + return self.reward @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -313,7 +355,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :type config: Dict """ node_hostname = config.get("node_hostname") - return cls(node_hostname=node_hostname) + sticky = config.get("sticky", True) + return cls(node_hostname=node_hostname, sticky=sticky) class SharedReward(AbstractReward): diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index e6cfa343..933afadf 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -73,7 +73,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - _last_connection_successful: Optional[bool] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. 
Used for rewards.""" last_query_response: Optional[Dict] = None @@ -135,8 +134,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): :return: A dictionary representing the current state. """ state = super().describe_state() - # list of connections that were established or verified during this step. - state["last_connection_successful"] = self._last_connection_successful return state def show(self, markdown: bool = False): @@ -226,13 +223,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"): f"Using connection id {database_client_connection}" ) self.connected = True - self._last_connection_successful = True return database_client_connection else: self.sys_log.info( f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined" ) - self._last_connection_successful = False return None else: self.sys_log.info( @@ -357,10 +352,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): success = self._query_success_tracker.get(query_id) if success: self.sys_log.info(f"{self.name}: Query successful {sql}") - self._last_connection_successful = True return True self.sys_log.error(f"{self.name}: Unable to run query {sql}") - self._last_connection_successful = False return False else: software_manager: SoftwareManager = self.software_manager @@ -390,9 +383,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): if not self.native_connection: return False - # reset last query response - self.last_query_response = None - uuid = str(uuid4()) self._query_success_tracker[uuid] = False return self.native_connection.query(sql) @@ -416,7 +406,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): connection_id=connection_id, connection_request_id=payload["connection_request_id"] ) elif payload["type"] == "sql": - self.last_query_response = payload query_id = payload.get("uuid") status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 diff --git 
a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 6f6fa335..4fc64e1f 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urlparse from primaite import getLogger @@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__) class WebServer(Service): """Class used to represent a Web Server Service in simulation.""" - last_response_status_code: Optional[HttpStatusCode] = None + response_codes_this_timestep: List[HttpStatusCode] = [] def describe_state(self) -> Dict: """ @@ -34,11 +34,19 @@ class WebServer(Service): :rtype: Dict """ state = super().describe_state() - state["last_response_status_code"] = ( - self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None - ) + state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep] return state + def pre_timestep(self, timestep: int) -> None: + """ + Logic to execute at the start of the timestep - clear the observation-related attributes. + + :param timestep: the current timestep in the episode. 
+ :type timestep: int + """ + self.response_codes_this_timestep = [] + return super().pre_timestep(timestep) + def __init__(self, **kwargs): kwargs["name"] = "WebServer" kwargs["protocol"] = IPProtocol.TCP @@ -89,7 +97,7 @@ class WebServer(Service): self.send(payload=response, session_id=session_id) # return true if response is OK - self.last_response_status_code = response.status_code + self.response_codes_this_timestep.append(response.status_code) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: From f344d292dbfe43482e1e021a571aece2105e948b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 13:59:35 +0100 Subject: [PATCH 2/5] #2736 - Fix up broken reward tests --- src/primaite/game/agent/rewards.py | 32 ++++++------- .../system/applications/database_client.py | 2 - .../game_layer/test_rewards.py | 46 ++++++++----------- .../test_data_manipulation_bot_and_server.py | 2 - .../test_ransomware_script.py | 11 +++-- 5 files changed, 39 insertions(+), 54 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 00374791..8ac3956c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -255,29 +255,26 @@ class WebpageUnavailablePenalty(AbstractReward): "execute", ] - if ( - not request_attempted and self.sticky - ): # skip calculating if sticky and no new codes, reusing last step value + # skip calculating if sticky and no new codes, reusing last step value + if not request_attempted and self.sticky: return self.reward if last_action_response.response.status != "success": self.reward = -1.0 - # - elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: + elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", f"{self._node} was not reported 
in the simulation state. Returning 0.0", ) self.reward = 0.0 - elif not web_browser_state["history"]: - self.reward = 0.0 # 0 if no requests have been attempted yet - outcome = web_browser_state["history"][-1]["outcome"] - if outcome == "PENDING": - self.reward = 0.0 # 0 if a request was attempted but not yet resolved - elif outcome == 200: - self.reward = 1.0 # 1 for successful request - else: # includes failure codes and SERVER_UNREACHABLE - self.reward = -1.0 # -1 for failure + else: + outcome = web_browser_state["history"][-1]["outcome"] + if outcome == "PENDING": + self.reward = 0.0 # 0 if a request was attempted but not yet resolved + elif outcome == 200: + self.reward = 1.0 # 1 for successful request + else: # includes failure codes and SERVER_UNREACHABLE + self.reward = -1.0 # -1 for failure return self.reward @@ -325,7 +322,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): db_state = access_from_nested_dict(state, self.location_in_state) # If the last request was actually sent, then check if the connection was established. 
- if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: + if db_state is NOT_PRESENT_IN_STATE: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") self.reward = 0.0 @@ -338,9 +335,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): "execute", ] - if ( - not request_attempted and self.sticky - ): # skip calculating if sticky and no new codes, reusing last step value + # skip calculating if sticky and no new codes, reusing last step value + if not request_attempted and self.sticky: return self.reward self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 933afadf..0a626c00 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -75,8 +75,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): server_password: Optional[str] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. Used for rewards.""" - last_query_response: Optional[Dict] = None - """Keep track of the latest query response. 
Used to determine rewards.""" _server_connection_id: Optional[str] = None """Connection ID to the Database Server.""" client_connections: Dict[str, DatabaseClientConnection] = {} diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..83b04832 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction, from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -19,32 +20,30 @@ from tests.conftest import ControlledAgent def test_WebpageUnavailablePenalty(game_and_agent): """Test that we get the right reward for failing to fetch a website.""" + # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent comp = WebpageUnavailablePenalty(node_hostname="client_1") - - agent.reward_function.register_component(comp, 0.7) - action = ("DONOTHING", {}) - agent.store_action(action) - game.step() - - # client 1 has not attempted to fetch webpage yet! 
- assert agent.reward_function.current_reward == 0.0 - client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() browser.target_url = "http://www.example.com" - assert browser.get_webpage() - action = ("DONOTHING", {}) - agent.store_action(action) + agent.reward_function.register_component(comp, 0.7) + + # Check that before trying to fetch the webpage, the reward is 0.0 + agent.store_action(("DONOTHING", {})) + game.step() + assert agent.reward_function.current_reward == 0.0 + + # Check that successfully fetching the webpage yields a reward of 0.7 + agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == 0.7 + # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) - assert not browser.get_webpage() - agent.store_action(action) + agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -70,32 +69,25 @@ def test_uc2_rewards(game_and_agent): comp = GreenAdminDatabaseUnreachablePenalty("client_1") - response = db_client.apply_request( - [ - "execute", - ] - ) + request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] + response = game.simulation.apply_request(request) state = game.get_sim_state() reward_value = comp.calculate( state, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response ), ) assert 
reward_value == 1.0 router.acl.remove_rule(position=2) - db_client.apply_request( - [ - "execute", - ] - ) + response = game.simulation.apply_request(request) state = game.get_sim_state() reward_value = comp.calculate( state, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response ), ) assert reward_value == -1.0 diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index a01cffbe..2e87578d 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -146,7 +146,6 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_ assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD assert green_db_connection.query("SELECT") - assert green_db_client.last_query_response.get("status_code") == 200 data_manipulation_bot.port_scan_p_of_success = 1 data_manipulation_bot.data_manipulation_p_of_success = 1 @@ -155,4 +154,3 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_ assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED assert green_db_connection.query("SELECT") is False - assert green_db_client.last_query_response.get("status_code") != 200 diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 2e3a0b1c..97abafb5 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -103,7 
+103,7 @@ def test_ransomware_script_attack(ransomware_script_and_db_server): def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client): - """Test to see show that the database service still operate""" + """Test to show that the database service still operates after corruption""" network: Network = ransomware_script_db_server_green_client client_1: Computer = network.get_node_by_hostname("client_1") @@ -111,17 +111,18 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_ client_2: Computer = network.get_node_by_hostname("client_2") green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + green_db_client.connect() green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection() server: Server = network.get_node_by_hostname("server_1") db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD - assert green_db_client_connection.query("SELECT") - assert green_db_client.last_query_response.get("status_code") == 200 + assert green_db_client.query("SELECT") is True ransomware_script_application.attack() + network.apply_timestep(0) + assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT - assert green_db_client_connection.query("SELECT") is True - assert green_db_client.last_query_response.get("status_code") == 200 + assert green_db_client.query("SELECT") is True # Still operates but now the data field of response is empty From 538e853f26591113cf1844d28bb7c6db7677fd0d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 15:32:25 +0100 Subject: [PATCH 3/5] #2736 - Add sticky reward tests and fix sticky reward behaviour --- src/primaite/game/agent/rewards.py | 23 +- .../_game/_agent/test_sticky_rewards.py | 299 ++++++++++++++++++ 2 files changed, 310 insertions(+), 12 deletions(-) create 
mode 100644 tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8ac3956c..321df098 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -180,13 +180,17 @@ class WebServer404Penalty(AbstractReward): return 0.0 codes = web_service_state.get("response_codes_this_timestep") - if codes or not self.sticky: # skip calculating if sticky and no new codes. Insted, reuse last step's value. + if codes: def status2rew(status: int) -> int: """Map status codes to reward values.""" return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average + elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + self.reward = 0 + else: # skip calculating if sticky and no new codes. insted, reuse last step's value + pass return self.reward @@ -319,13 +323,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request returned was able to connect to the database server, because there has been an unsuccessful request since. """ - db_state = access_from_nested_dict(state, self.location_in_state) - - # If the last request was actually sent, then check if the connection was established. 
- if db_state is NOT_PRESENT_IN_STATE: - _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") - self.reward = 0.0 - request_attempted = last_action_response.request == [ "network", "node", @@ -335,11 +332,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): "execute", ] - # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: - return self.reward + if request_attempted: # if agent makes request, always recalculate fresh value + self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 + elif not self.sticky: # if no new request and not sticky, set reward to 0 + self.reward = 0.0 + else: # if no new request and sticky, reuse reward value from last step + pass - self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 return self.reward @classmethod diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py new file mode 100644 index 00000000..58f0fcc1 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -0,0 +1,299 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +from primaite.game.agent.interface import AgentHistoryItem +from primaite.game.agent.rewards import ( + GreenAdminDatabaseUnreachablePenalty, + WebpageUnavailablePenalty, + WebServer404Penalty, +) +from primaite.interface.request import RequestResponse + + +class TestWebServer404PenaltySticky: + def test_non_sticky(self): + reward = WebServer404Penalty("computer", "WebService", sticky=False) + + # no response codes yet, reward is 0 + codes = [] + state = { + "network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}} + } + last_action_response = None + assert reward.calculate(state, last_action_response) == 0 + + # update codes (by reference), 200 response code is now present + 
codes.append(200) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # update codes (by reference), to make it empty again, reward goes back to 0 + codes.pop() + assert reward.calculate(state, last_action_response) == 0.0 + + # update codes (by reference), 404 response code is now present, reward = -1.0 + codes.append(404) + assert reward.calculate(state, last_action_response) == -1.0 + + # don't update codes, it still has just a 404, check the reward is -1.0 again + assert reward.calculate(state, last_action_response) == -1.0 + + def test_sticky(self): + reward = WebServer404Penalty("computer", "WebService", sticky=True) + + # no response codes yet, reward is 0 + codes = [] + state = { + "network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}} + } + last_action_response = None + assert reward.calculate(state, last_action_response) == 0 + + # update codes (by reference), 200 response code is now present + codes.append(200) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # update codes (by reference), to make it empty again, reward remains at 1.0 because it's sticky + codes.pop() + assert reward.calculate(state, last_action_response) == 1.0 + + # update codes (by reference), 404 response code is now present, reward = -1.0 + codes.append(404) + assert reward.calculate(state, last_action_response) == -1.0 + + # don't update codes, it still has just a 404, check the reward is -1.0 again + assert reward.calculate(state, last_action_response) == -1.0 + + +class TestWebpageUnavailabilitySticky: + def test_non_sticky(self): + reward = WebpageUnavailablePenalty("computer", sticky=False) + + # no response codes yet, reward is 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + browser_history = [] + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": 
browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0 + + # agent did a successful fetch + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="success", data={}) + browser_history.append({"outcome": 200}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # agent did nothing, because reward is not sticky, it goes back to 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + browser_history = [] + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0.0 + + # agent fails to fetch, get a -1.0 reward + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="failure", data={}) + browser_history.append({"outcome": 404}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + # agent fails again to 
fetch, get a -1.0 reward again + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="failure", data={}) + browser_history.append({"outcome": 404}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + def test_sticky(self): + reward = WebpageUnavailablePenalty("computer", sticky=True) + + # no response codes yet, reward is 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + browser_history = [] + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0 + + # agent did a successful fetch + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="success", data={}) + browser_history.append({"outcome": 200}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # agent did nothing, because reward is sticky, it stays at 1.0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + state = 
{"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # agent fails to fetch, get a -1.0 reward + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="failure", data={}) + browser_history.append({"outcome": 404}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + # agent fails again to fetch, get a -1.0 reward again + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "WebBrowser", "execute"] + response = RequestResponse(status="failure", data={}) + browser_history.append({"outcome": 404}) + state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + +class TestGreenAdminDatabaseUnreachableSticky: + def test_non_sticky(self): + reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + + # no response codes yet, reward is 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, 
action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0 + + # agent did a successful fetch + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = RequestResponse(status="success", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # agent did nothing, because reward is not sticky, it goes back to 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + browser_history = [] + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0.0 + + # agent fails to fetch, get a -1.0 reward + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = RequestResponse(status="failure", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + # agent fails again to fetch, get a -1.0 reward again + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = 
RequestResponse(status="failure", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + def test_sticky(self): + reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + + # no response codes yet, reward is 0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 0 + + # agent did a successful fetch + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = RequestResponse(status="success", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # THE IMPORTANT BIT + # agent did nothing, because reward is sticky, it stays at 1.0 + action, params, request = "DO_NOTHING", {}, ["DONOTHING"] + response = RequestResponse(status="success", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == 1.0 + + # agent fails to fetch, get a -1.0 reward + action = "NODE_APPLICATION_EXECUTE" + params = 
{"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = RequestResponse(status="failure", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 + + # agent fails again to fetch, get a -1.0 reward again + action = "NODE_APPLICATION_EXECUTE" + params = {"node_id": 0, "application_id": 0} + request = ["network", "node", "computer", "application", "DatabaseClient", "execute"] + response = RequestResponse(status="failure", data={}) + state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} + last_action_response = AgentHistoryItem( + timestep=0, action=action, parameters=params, request=request, response=response + ) + assert reward.calculate(state, last_action_response) == -1.0 From 15b7334f05411957d37ed8a19f7a57cdda0e59a5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 15:34:50 +0100 Subject: [PATCH 4/5] #2736 - Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c63b114..5aba9e6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. +- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where the agent doesn't issue the corresponding action. 
Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty` ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) From 1833dc39468fd5673869149a388b070b01a97693 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Aug 2024 10:41:40 +0100 Subject: [PATCH 5/5] #2736 - typo fixes --- src/primaite/game/agent/rewards.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 73bc7b11..b97b7c5a 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -208,8 +208,8 @@ class WebServer404Penalty(AbstractReward): self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 - self.reward = 0 - else: # skip calculating if sticky and no new codes. insted, reuse last step's value + self.reward = 0.0 + else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass return self.reward