Merged PR 508: Add option for rewards to be instantaneous

## Summary
* Changed how `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, and `WebServer404Penalty` work.
* They can now be configured with `sticky: false` in the yaml
* which means they no longer retain a positive/negative value after a successful/failed request, if the agent goes on to do nothing the next step
* refactored the calculate methods to better align with those rewards depending the previous action
* changed what is returned by some of the `describe_state` methods of sim components. They had legacy methods of returning the most recent success code which is no longer needed since the introduction of agent history

## Test process
Existing tests pass, new tests added

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

Related work items: #2736
This commit is contained in:
Marek Wolan
2024-08-20 10:40:43 +00:00
8 changed files with 441 additions and 109 deletions

View File

@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where agent don't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty`
- Node observations can now be configured to show the number of active local and remote logins.
### Fixed

View File

@@ -171,14 +171,20 @@ class DatabaseFileIntegrity(AbstractReward):
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str) -> None:
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
@@ -188,16 +194,25 @@ class WebServer404Penalty(AbstractReward):
:type state: Dict
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
# if webserver is no longer installed on the node, return 0
if web_service_state is NOT_PRESENT_IN_STATE:
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1.0
elif most_recent_return_code == 404:
return -1.0
else:
return 0.0
codes = web_service_state.get("response_codes_this_timestep")
if codes:
def status2rew(status: int) -> int:
"""Map status codes to reward values."""
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
self.reward = 0.0
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
pass
return self.reward
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
@@ -217,23 +232,29 @@ class WebServer404Penalty(AbstractReward):
)
_LOGGER.warning(msg)
raise ValueError(msg)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, service_name=service_name)
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -243,31 +264,43 @@ class WebpageUnavailablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
if web_browser_state is NOT_PRESENT_IN_STATE:
self.reward = 0.0
# check if the most recent action was to request the webpage
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"WebBrowser",
"execute",
]
# skip calculating if sticky and no new codes, reusing last step value
if not request_attempted and self.sticky:
return self.reward
if last_action_response.response.status != "success":
self.reward = -1.0
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
)
return 0.0 # 0 if the web browser cannot be found
if not web_browser_state["history"]:
return 0.0 # 0 if no requests have been attempted yet
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
return 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
return 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
return -1.0 # -1 for failure
self.reward = 0.0
else:
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
self.reward = 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
self.reward = -1.0 # -1 for failure
return self.reward
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
@@ -278,22 +311,28 @@ class WebpageUnavailablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str) -> None:
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -310,22 +349,26 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
"application",
"DatabaseClient",
"execute",
]
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
if request_attempted: # if agent makes request, always recalculate fresh value
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
elif not self.sticky: # if no new request and not sticky, set reward to 0
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
self.reward = 0.0
else: # if no new request and sticky, reuse reward value from last step
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
pass
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
return 0.0
last_connection_successful = db_state["last_connection_successful"]
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
return 1.0 if last_connection_successful else -1.0
return self.reward
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -336,7 +379,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):

View File

@@ -73,11 +73,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
_last_connection_successful: Optional[bool] = None
_query_success_tracker: Dict[str, bool] = {}
"""Keep track of connections that were established or verified during this step. Used for rewards."""
last_query_response: Optional[Dict] = None
"""Keep track of the latest query response. Used to determine rewards."""
_server_connection_id: Optional[str] = None
"""Connection ID to the Database Server."""
client_connections: Dict[str, DatabaseClientConnection] = {}
@@ -135,8 +132,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
# list of connections that were established or verified during this step.
state["last_connection_successful"] = self._last_connection_successful
return state
def show(self, markdown: bool = False):
@@ -226,13 +221,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
f"Using connection id {database_client_connection}"
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
self._last_connection_successful = False
return None
else:
self.sys_log.info(
@@ -357,10 +350,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
success = self._query_success_tracker.get(query_id)
if success:
self.sys_log.info(f"{self.name}: Query successful {sql}")
self._last_connection_successful = True
return True
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
self._last_connection_successful = False
return False
else:
software_manager: SoftwareManager = self.software_manager
@@ -390,9 +381,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
if not self.native_connection:
return False
# reset last query response
self.last_query_response = None
uuid = str(uuid4())
self._query_success_tracker[uuid] = False
return self.native_connection.query(sql)
@@ -416,7 +404,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
)
elif payload["type"] == "sql":
self.last_query_response = payload
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from primaite import getLogger
@@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__)
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
last_response_status_code: Optional[HttpStatusCode] = None
response_codes_this_timestep: List[HttpStatusCode] = []
def describe_state(self) -> Dict:
"""
@@ -34,11 +34,19 @@ class WebServer(Service):
:rtype: Dict
"""
state = super().describe_state()
state["last_response_status_code"] = (
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep]
return state
def pre_timestep(self, timestep: int) -> None:
"""
Logic to execute at the start of the timestep - clear the observation-related attributes.
:param timestep: the current timestep in the episode.
:type timestep: int
"""
self.response_codes_this_timestep = []
return super().pre_timestep(timestep)
def __init__(self, **kwargs):
kwargs["name"] = "WebServer"
kwargs["protocol"] = IPProtocol.TCP
@@ -89,7 +97,7 @@ class WebServer(Service):
self.send(payload=response, session_id=session_id)
# return true if response is OK
self.last_response_status_code = response.status_code
self.response_codes_this_timestep.append(response.status_code)
return response.status_code == HttpStatusCode.OK
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:

View File

@@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction,
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
@@ -19,32 +20,30 @@ from tests.conftest import ControlledAgent
def test_WebpageUnavailablePenalty(game_and_agent):
"""Test that we get the right reward for failing to fetch a website."""
# set up the scenario, configure the web browser to the correct url
game, agent = game_and_agent
agent: ControlledAgent
comp = WebpageUnavailablePenalty(node_hostname="client_1")
agent.reward_function.register_component(comp, 0.7)
action = ("DONOTHING", {})
agent.store_action(action)
game.step()
# client 1 has not attempted to fetch webpage yet!
assert agent.reward_function.current_reward == 0.0
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser = client_1.software_manager.software.get("WebBrowser")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
assert browser.get_webpage()
action = ("DONOTHING", {})
agent.store_action(action)
agent.reward_function.register_component(comp, 0.7)
# Check that before trying to fetch the webpage, the reward is 0.0
agent.store_action(("DONOTHING", {}))
game.step()
assert agent.reward_function.current_reward == 0.0
# Check that successfully fetching the webpage yields a reward of 0.7
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
game.step()
assert agent.reward_function.current_reward == 0.7
# Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7
router: Router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP)
assert not browser.get_webpage()
agent.store_action(action)
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
game.step()
assert agent.reward_function.current_reward == -0.7
@@ -70,34 +69,29 @@ def test_uc2_rewards(game_and_agent):
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
response = db_client.apply_request(
[
"execute",
]
)
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
response = game.simulation.apply_request(request)
state = game.get_sim_state()
ahi = AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={},
request=["execute"],
response=response,
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == 1.0
assert ahi.reward_info == {"last_connection_successful": True}
assert ahi.reward_info == {"connection_attempt_status": "success"}
router.acl.remove_rule(position=2)
db_client.apply_request(
[
"execute",
]
)
response = game.simulation.apply_request(request)
state = game.get_sim_state()
reward_value = comp.calculate(state, last_action_response=ahi)
ahi = AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
)
reward_value = comp.calculate(
state,
last_action_response=ahi,
)
assert reward_value == -1.0
assert ahi.reward_info == {"last_connection_successful": False}
assert ahi.reward_info == {"connection_attempt_status": "failure"}
def test_shared_reward():

View File

@@ -146,7 +146,6 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
assert green_db_connection.query("SELECT")
assert green_db_client.last_query_response.get("status_code") == 200
data_manipulation_bot.port_scan_p_of_success = 1
data_manipulation_bot.data_manipulation_p_of_success = 1
@@ -155,4 +154,3 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
assert green_db_connection.query("SELECT") is False
assert green_db_client.last_query_response.get("status_code") != 200

View File

@@ -103,7 +103,7 @@ def test_ransomware_script_attack(ransomware_script_and_db_server):
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
"""Test to see show that the database service still operate"""
"""Test to show that the database service still operates after corruption"""
network: Network = ransomware_script_db_server_green_client
client_1: Computer = network.get_node_by_hostname("client_1")
@@ -111,17 +111,18 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
client_2: Computer = network.get_node_by_hostname("client_2")
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
green_db_client.connect()
green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection()
server: Server = network.get_node_by_hostname("server_1")
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
assert green_db_client_connection.query("SELECT")
assert green_db_client.last_query_response.get("status_code") == 200
assert green_db_client.query("SELECT") is True
ransomware_script_application.attack()
network.apply_timestep(0)
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
assert green_db_client_connection.query("SELECT") is True
assert green_db_client.last_query_response.get("status_code") == 200
assert green_db_client.query("SELECT") is True # Still operates but now the data field of response is empty

View File

@@ -0,0 +1,299 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import (
GreenAdminDatabaseUnreachablePenalty,
WebpageUnavailablePenalty,
WebServer404Penalty,
)
from primaite.interface.request import RequestResponse
class TestWebServer404PenaltySticky:
def test_non_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=False)
# no response codes yet, reward is 0
codes = []
state = {
"network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
}
last_action_response = None
assert reward.calculate(state, last_action_response) == 0
# update codes (by reference), 200 response code is now present
codes.append(200)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# update codes (by reference), to make it empty again, reward goes back to 0
codes.pop()
assert reward.calculate(state, last_action_response) == 0.0
# update codes (by reference), 404 response code is now present, reward = -1.0
codes.append(404)
assert reward.calculate(state, last_action_response) == -1.0
# don't update codes, it still has just a 404, check the reward is -1.0 again
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=True)
# no response codes yet, reward is 0
codes = []
state = {
"network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
}
last_action_response = None
assert reward.calculate(state, last_action_response) == 0
# update codes (by reference), 200 response code is now present
codes.append(200)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# update codes (by reference), to make it empty again, reward remains at 1.0 because it's sticky
codes.pop()
assert reward.calculate(state, last_action_response) == 1.0
# update codes (by reference), 404 response code is now present, reward = -1.0
codes.append(404)
assert reward.calculate(state, last_action_response) == -1.0
# don't update codes, it still has just a 404, check the reward is -1.0 again
assert reward.calculate(state, last_action_response) == -1.0
class TestWebpageUnavailabilitySticky:
def test_non_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=False)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0
# agent did a successful fetch
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="success", data={})
browser_history.append({"outcome": 200})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0.0
# agent fails to fetch, get a -1.0 reward
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="failure", data={})
browser_history.append({"outcome": 404})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
# agent fails again to fetch, get a -1.0 reward again
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="failure", data={})
browser_history.append({"outcome": 404})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=True)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0
# agent did a successful fetch
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="success", data={})
browser_history.append({"outcome": 200})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# agent did nothing, because reward is sticky, it stays at 1.0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# agent fails to fetch, get a -1.0 reward
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="failure", data={})
browser_history.append({"outcome": 404})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
# agent fails again to fetch, get a -1.0 reward again
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
response = RequestResponse(status="failure", data={})
browser_history.append({"outcome": 404})
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
class TestGreenAdminDatabaseUnreachableSticky:
def test_non_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0
# agent did a successful fetch
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0.0
# agent fails to fetch, get a -1.0 reward
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="failure", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
# agent fails again to fetch, get a -1.0 reward again
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="failure", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 0
# agent did a successful fetch
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# THE IMPORTANT BIT
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == 1.0
# agent fails to fetch, get a -1.0 reward
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="failure", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0
# agent fails again to fetch, get a -1.0 reward again
action = "NODE_APPLICATION_EXECUTE"
params = {"node_id": 0, "application_id": 0}
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
response = RequestResponse(status="failure", data={})
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
)
assert reward.calculate(state, last_action_response) == -1.0