#2736 - implement instantaneous rewards

This commit is contained in:
Marek Wolan
2024-08-19 10:17:39 +01:00
parent c886d4b014
commit 05f9751fa8
3 changed files with 104 additions and 64 deletions

View File

@@ -151,14 +151,20 @@ class DatabaseFileIntegrity(AbstractReward):
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str) -> None:
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
@@ -168,16 +174,21 @@ class WebServer404Penalty(AbstractReward):
:type state: Dict
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
# if webserver is no longer installed on the node, return 0
if web_service_state is NOT_PRESENT_IN_STATE:
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1.0
elif most_recent_return_code == 404:
return -1.0
else:
return 0.0
codes = web_service_state.get("response_codes_this_timestep")
if codes or not self.sticky: # skip calculating if sticky and no new codes. Instead, reuse last step's value.
def status2rew(status: int) -> int:
"""Map status codes to reward values."""
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
self.reward = sum(map(status2rew, codes)) / len(codes) # convert from HTTP codes to rewards and average
return self.reward
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
@@ -197,23 +208,29 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warning(msg)
raise ValueError(msg)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, service_name=service_name)
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -223,31 +240,46 @@ class WebpageUnavailablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
if web_browser_state is NOT_PRESENT_IN_STATE:
self.reward = 0.0
# check if the most recent action was to request the webpage
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"WebBrowser",
"execute",
]
if (
not request_attempted and self.sticky
): # skip calculating if sticky and no new codes, reusing last step value
return self.reward
if last_action_response.response.status != "success":
self.reward = -1.0
#
elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
)
return 0.0 # 0 if the web browser cannot be found
if not web_browser_state["history"]:
return 0.0 # 0 if no requests have been attempted yet
self.reward = 0.0
elif not web_browser_state["history"]:
self.reward = 0.0 # 0 if no requests have been attempted yet
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
return 0.0 # 0 if a request was attempted but not yet resolved
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
return 1.0 # 1 for successful request
self.reward = 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
return -1.0 # -1 for failure
self.reward = -1.0 # -1 for failure
return self.reward
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
@@ -258,22 +290,28 @@ class WebpageUnavailablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -284,25 +322,29 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
db_state = access_from_nested_dict(state, self.location_in_state)
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
self.reward = 0.0
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"DatabaseClient",
"execute",
]
if (
not request_attempted and self.sticky
): # skip calculating if sticky and no new codes, reusing last step value
return self.reward
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
return self.reward
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -313,7 +355,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):

View File

@@ -73,7 +73,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
_last_connection_successful: Optional[bool] = None
_query_success_tracker: Dict[str, bool] = {}
"""Keep track of connections that were established or verified during this step. Used for rewards."""
last_query_response: Optional[Dict] = None
@@ -135,8 +134,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
# list of connections that were established or verified during this step.
state["last_connection_successful"] = self._last_connection_successful
return state
def show(self, markdown: bool = False):
@@ -226,13 +223,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
f"Using connection id {database_client_connection}"
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
self._last_connection_successful = False
return None
else:
self.sys_log.info(
@@ -357,10 +352,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
success = self._query_success_tracker.get(query_id)
if success:
self.sys_log.info(f"{self.name}: Query successful {sql}")
self._last_connection_successful = True
return True
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
self._last_connection_successful = False
return False
else:
software_manager: SoftwareManager = self.software_manager
@@ -390,9 +383,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
if not self.native_connection:
return False
# reset last query response
self.last_query_response = None
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self.native_connection.query(sql)
@@ -416,7 +406,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
)
elif payload["type"] == "sql":
self.last_query_response = payload
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from primaite import getLogger
@@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__)
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
last_response_status_code: Optional[HttpStatusCode] = None
response_codes_this_timestep: List[HttpStatusCode] = []
def describe_state(self) -> Dict:
"""
@@ -34,11 +34,19 @@ class WebServer(Service):
:rtype: Dict
"""
state = super().describe_state()
state["last_response_status_code"] = (
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep]
return state
def pre_timestep(self, timestep: int) -> None:
"""
Logic to execute at the start of the timestep - clear the observation-related attributes.
:param timestep: the current timestep in the episode.
:type timestep: int
"""
self.response_codes_this_timestep = []
return super().pre_timestep(timestep)
def __init__(self, **kwargs):
kwargs["name"] = "WebServer"
kwargs["protocol"] = IPProtocol.TCP
@@ -89,7 +97,7 @@ class WebServer(Service):
self.send(payload=response, session_id=session_id)
# return true if response is OK
self.last_response_status_code = response.status_code
self.response_codes_this_timestep.append(response.status_code)
return response.status_code == HttpStatusCode.OK
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket: