diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..00374791 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -151,14 +151,20 @@ class DatabaseFileIntegrity(AbstractReward): class WebServer404Penalty(AbstractReward): """Reward function component which penalises the agent when the web server returns a 404 error.""" - def __init__(self, node_hostname: str, service_name: str) -> None: + def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: """Initialise the reward component. :param node_hostname: Hostname of the node which contains the web server service. :type node_hostname: str :param service_name: Name of the web server service. :type service_name: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. + :type sticky: bool """ + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: @@ -168,16 +174,21 @@ class WebServer404Penalty(AbstractReward): :type state: Dict """ web_service_state = access_from_nested_dict(state, self.location_in_state) + + # if webserver is no longer installed on the node, return 0 if web_service_state is NOT_PRESENT_IN_STATE: return 0.0 - most_recent_return_code = web_service_state["last_response_status_code"] - # TODO: reward needs to use the current web state. Observation should return web state at the time of last scan. 
- if most_recent_return_code == 200: - return 1.0 - elif most_recent_return_code == 404: - return -1.0 - else: - return 0.0 + + codes = web_service_state.get("response_codes_this_timestep") + if codes or not self.sticky: # skip calculating if sticky and no new codes. Instead, reuse last step's value. + + def status2rew(status: int) -> int: + """Map status codes to reward values.""" + return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 + + self.reward = sum(map(status2rew, codes)) / len(codes) # convert from HTTP codes to rewards and average + + return self.reward @classmethod def from_config(cls, config: Dict) -> "WebServer404Penalty": @@ -197,23 +208,29 @@ class WebServer404Penalty(AbstractReward): ) _LOGGER.warning(msg) raise ValueError(msg) + sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, service_name=service_name) + return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) class WebpageUnavailablePenalty(AbstractReward): """Penalises the agent when the web browser fails to fetch a webpage.""" - def __init__(self, node_hostname: str) -> None: + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. :param node_hostname: Hostname of the node which has the web browser. :type node_hostname: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. + :type sticky: bool """ self._node: str = node_hostname self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self._last_request_failed: bool = False + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. 
Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -223,31 +240,46 @@ class WebpageUnavailablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last webpage had a 200 status code, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]: - self._last_request_failed = last_action_response.response.status != "success" - - # if agent couldn't even get as far as sending the request (because for example the node was off), then - # apply a penalty - if self._last_request_failed: - return -1.0 - - # If the last request did actually go through, then check if the webpage also loaded web_browser_state = access_from_nested_dict(state, self.location_in_state) - if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: + + if web_browser_state is NOT_PRESENT_IN_STATE: + self.reward = 0.0 + + # check if the most recent action was to request the webpage + request_attempted = last_action_response.request == [ + "network", + "node", + self._node, + "application", + "WebBrowser", + "execute", + ] + + if ( + not request_attempted and self.sticky + ): # skip calculating if sticky and no new codes, reusing last step value + return self.reward + + if last_action_response.response.status != "success": + self.reward = -1.0 + # + elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", f"{self._node} was not reported in the simulation state. 
Returning 0.0", ) - return 0.0 # 0 if the web browser cannot be found - if not web_browser_state["history"]: - return 0.0 # 0 if no requests have been attempted yet + self.reward = 0.0 + elif not web_browser_state["history"]: + self.reward = 0.0 # 0 if no requests have been attempted yet outcome = web_browser_state["history"][-1]["outcome"] if outcome == "PENDING": - return 0.0 # 0 if a request was attempted but not yet resolved + self.reward = 0.0 # 0 if a request was attempted but not yet resolved elif outcome == 200: - return 1.0 # 1 for successful request + self.reward = 1.0 # 1 for successful request else: # includes failure codes and SERVER_UNREACHABLE - return -1.0 # -1 for failure + self.reward = -1.0 # -1 for failure + + return self.reward @classmethod def from_config(cls, config: dict) -> AbstractReward: @@ -258,22 +290,28 @@ class WebpageUnavailablePenalty(AbstractReward): :type config: Dict """ node_hostname = config.get("node_hostname") - return cls(node_hostname=node_hostname) + sticky = config.get("sticky", True) + return cls(node_hostname=node_hostname, sticky=sticky) class GreenAdminDatabaseUnreachablePenalty(AbstractReward): """Penalises the agent when the green db clients fail to connect to the database.""" - def __init__(self, node_hostname: str) -> None: + def __init__(self, node_hostname: str, sticky: bool = True) -> None: """ Initialise the reward component. :param node_hostname: Hostname of the node where the database client sits. :type node_hostname: str + :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate + the reward if there were any responses this timestep. + :type sticky: bool """ self._node: str = node_hostname self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self._last_request_failed: bool = False + self.sticky: bool = sticky + self.reward: float = 0.0 + """Reward value calculated last time any responses were seen. 
Used for persisting sticky rewards.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -284,25 +322,29 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request returned was able to connect to the database server, because there has been an unsuccessful request since. """ - if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: - self._last_request_failed = last_action_response.response.status != "success" - - # if agent couldn't even get as far as sending the request (because for example the node was off), then - # apply a penalty - if self._last_request_failed: - return -1.0 + db_state = access_from_nested_dict(state, self.location_in_state) # If the last request was actually sent, then check if the connection was established. - db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") - return 0.0 - last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + self.reward = 0.0 + + request_attempted = last_action_response.request == [ + "network", + "node", + self._node, + "application", + "DatabaseClient", + "execute", + ] + + if ( + not request_attempted and self.sticky + ): # skip calculating if sticky and no new codes, reusing last step value + return self.reward + + self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 + return self.reward @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -313,7 +355,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :type config: Dict """ node_hostname = config.get("node_hostname") - return cls(node_hostname=node_hostname) + sticky = config.get("sticky", True) 
+ return cls(node_hostname=node_hostname, sticky=sticky) class SharedReward(AbstractReward): diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index e6cfa343..933afadf 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -73,7 +73,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): server_ip_address: Optional[IPv4Address] = None server_password: Optional[str] = None - _last_connection_successful: Optional[bool] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. Used for rewards.""" last_query_response: Optional[Dict] = None @@ -135,8 +134,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): :return: A dictionary representing the current state. """ state = super().describe_state() - # list of connections that were established or verified during this step. 
- state["last_connection_successful"] = self._last_connection_successful return state def show(self, markdown: bool = False): @@ -226,13 +223,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"): f"Using connection id {database_client_connection}" ) self.connected = True - self._last_connection_successful = True return database_client_connection else: self.sys_log.info( f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined" ) - self._last_connection_successful = False return None else: self.sys_log.info( @@ -357,10 +352,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"): success = self._query_success_tracker.get(query_id) if success: self.sys_log.info(f"{self.name}: Query successful {sql}") - self._last_connection_successful = True return True self.sys_log.error(f"{self.name}: Unable to run query {sql}") - self._last_connection_successful = False return False else: software_manager: SoftwareManager = self.software_manager @@ -390,9 +383,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): if not self.native_connection: return False - # reset last query response - self.last_query_response = None - uuid = str(uuid4()) self._query_success_tracker[uuid] = False return self.native_connection.query(sql) @@ -416,7 +406,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): connection_id=connection_id, connection_request_id=payload["connection_request_id"] ) elif payload["type"] == "sql": - self.last_query_response = payload query_id = payload.get("uuid") status_code = payload.get("status_code") self._query_success_tracker[query_id] = status_code == 200 diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 6f6fa335..4fc64e1f 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -1,6 
+1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urlparse from primaite import getLogger @@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__) class WebServer(Service): """Class used to represent a Web Server Service in simulation.""" - last_response_status_code: Optional[HttpStatusCode] = None + response_codes_this_timestep: List[HttpStatusCode] = [] def describe_state(self) -> Dict: """ @@ -34,11 +34,19 @@ class WebServer(Service): :rtype: Dict """ state = super().describe_state() - state["last_response_status_code"] = ( - self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None - ) + state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep] return state + def pre_timestep(self, timestep: int) -> None: + """ + Logic to execute at the start of the timestep - clear the observation-related attributes. + + :param timestep: the current timestep in the episode. + :type timestep: int + """ + self.response_codes_this_timestep = [] + return super().pre_timestep(timestep) + def __init__(self, **kwargs): kwargs["name"] = "WebServer" kwargs["protocol"] = IPProtocol.TCP @@ -89,7 +97,7 @@ class WebServer(Service): self.send(payload=response, session_id=session_id) # return true if response is OK - self.last_response_status_code = response.status_code + self.response_codes_this_timestep.append(response.status_code) return response.status_code == HttpStatusCode.OK def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: