#2736 - implement instantaneous rewards
This commit is contained in:
@@ -151,14 +151,20 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_hostname: str, service_name: str) -> None:
|
||||
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the web server service.
|
||||
:type node_hostname: str
|
||||
:param service_name: Name of the web server service.
|
||||
:type service_name: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
@@ -168,16 +174,21 @@ class WebServer404Penalty(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# if webserver is no longer installed on the node, return 0
|
||||
if web_service_state is NOT_PRESENT_IN_STATE:
|
||||
return 0.0
|
||||
most_recent_return_code = web_service_state["last_response_status_code"]
|
||||
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
|
||||
if most_recent_return_code == 200:
|
||||
return 1.0
|
||||
elif most_recent_return_code == 404:
|
||||
return -1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
codes = web_service_state.get("response_codes_this_timestep")
|
||||
if codes or not self.sticky: # skip calculating if sticky and no new codes. Insted, reuse last step's value.
|
||||
|
||||
def status2rew(status: int) -> int:
|
||||
"""Map status codes to reward values."""
|
||||
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
|
||||
|
||||
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "WebServer404Penalty":
|
||||
@@ -197,23 +208,29 @@ class WebServer404Penalty(AbstractReward):
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
raise ValueError(msg)
|
||||
sticky = config.get("sticky", True)
|
||||
|
||||
return cls(node_hostname=node_hostname, service_name=service_name)
|
||||
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
def __init__(self, node_hostname: str) -> None:
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -223,31 +240,46 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request did actually go through, then check if the webpage also loaded
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE:
|
||||
self.reward = 0.0
|
||||
|
||||
# check if the most recent action was to request the webpage
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
"application",
|
||||
"WebBrowser",
|
||||
"execute",
|
||||
]
|
||||
|
||||
if (
|
||||
not request_attempted and self.sticky
|
||||
): # skip calculating if sticky and no new codes, reusing last step value
|
||||
return self.reward
|
||||
|
||||
if last_action_response.response.status != "success":
|
||||
self.reward = -1.0
|
||||
#
|
||||
elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
_LOGGER.debug(
|
||||
"Web browser reward could not be calculated because the web browser history on node",
|
||||
f"{self._node} was not reported in the simulation state. Returning 0.0",
|
||||
)
|
||||
return 0.0 # 0 if the web browser cannot be found
|
||||
if not web_browser_state["history"]:
|
||||
return 0.0 # 0 if no requests have been attempted yet
|
||||
self.reward = 0.0
|
||||
elif not web_browser_state["history"]:
|
||||
self.reward = 0.0 # 0 if no requests have been attempted yet
|
||||
outcome = web_browser_state["history"][-1]["outcome"]
|
||||
if outcome == "PENDING":
|
||||
return 0.0 # 0 if a request was attempted but not yet resolved
|
||||
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
|
||||
elif outcome == 200:
|
||||
return 1.0 # 1 for successful request
|
||||
self.reward = 1.0 # 1 for successful request
|
||||
else: # includes failure codes and SERVER_UNREACHABLE
|
||||
return -1.0 # -1 for failure
|
||||
self.reward = -1.0 # -1 for failure
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> AbstractReward:
|
||||
@@ -258,22 +290,28 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
return cls(node_hostname=node_hostname)
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
def __init__(self, node_hostname: str) -> None:
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -284,25 +322,29 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
return 0.0
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
if last_connection_successful is False:
|
||||
return -1.0
|
||||
elif last_connection_successful is True:
|
||||
return 1.0
|
||||
return 0.0
|
||||
self.reward = 0.0
|
||||
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
"application",
|
||||
"DatabaseClient",
|
||||
"execute",
|
||||
]
|
||||
|
||||
if (
|
||||
not request_attempted and self.sticky
|
||||
): # skip calculating if sticky and no new codes, reusing last step value
|
||||
return self.reward
|
||||
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
@@ -313,7 +355,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
return cls(node_hostname=node_hostname)
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
|
||||
@@ -73,7 +73,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
@@ -135,8 +134,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
# list of connections that were established or verified during this step.
|
||||
state["last_connection_successful"] = self._last_connection_successful
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -226,13 +223,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
f"Using connection id {database_client_connection}"
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
else:
|
||||
self.sys_log.info(
|
||||
@@ -357,10 +352,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
success = self._query_success_tracker.get(query_id)
|
||||
if success:
|
||||
self.sys_log.info(f"{self.name}: Query successful {sql}")
|
||||
self._last_connection_successful = True
|
||||
return True
|
||||
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
|
||||
self._last_connection_successful = False
|
||||
return False
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -390,9 +383,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
if not self.native_connection:
|
||||
return False
|
||||
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self.native_connection.query(sql)
|
||||
@@ -416,7 +406,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
self.last_query_response = payload
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__)
|
||||
class WebServer(Service):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
response_codes_this_timestep: List[HttpStatusCode] = []
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -34,11 +34,19 @@ class WebServer(Service):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
|
||||
)
|
||||
state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep]
|
||||
return state
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Logic to execute at the start of the timestep - clear the observation-related attributes.
|
||||
|
||||
:param timestep: the current timestep in the episode.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.response_codes_this_timestep = []
|
||||
return super().pre_timestep(timestep)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "WebServer"
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
@@ -89,7 +97,7 @@ class WebServer(Service):
|
||||
self.send(payload=response, session_id=session_id)
|
||||
|
||||
# return true if response is OK
|
||||
self.last_response_status_code = response.status_code
|
||||
self.response_codes_this_timestep.append(response.status_code)
|
||||
return response.status_code == HttpStatusCode.OK
|
||||
|
||||
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
|
||||
|
||||
Reference in New Issue
Block a user