Merged PR 508: Add option for rewards to be instantaneous
## Summary * Changed how `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, and `WebServer404Penalty` work. * They can now be configured with `sticky: false` in the yaml * which means they no longer retain a positive/negative value after a successful/failed request, if the agent goes on to do nothing the next step * refactored the calculate methods to better align with those rewards depending on the previous action * changed what is returned by some of the `describe_state` methods of sim components. They had legacy methods of returning the most recent success code which is no longer needed since the introduction of agent history ## Test process Existing tests pass, new tests added ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [X] attended to any **TO-DOs** left in the code Related work items: #2736
This commit is contained in:
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
|
||||
- It's now possible to disable stickiness on reward components, meaning their value returns to 0 during timesteps where the agent doesn't issue the corresponding action. Affects `GreenAdminDatabaseUnreachablePenalty`, `WebpageUnavailablePenalty`, `WebServer404Penalty`
|
||||
- Node observations can now be configured to show the number of active local and remote logins.
|
||||
|
||||
### Fixed
|
||||
|
||||
@@ -171,14 +171,20 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
class WebServer404Penalty(AbstractReward):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
def __init__(self, node_hostname: str, service_name: str) -> None:
|
||||
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the web server service.
|
||||
:type node_hostname: str
|
||||
:param service_name: Name of the web server service.
|
||||
:type service_name: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
@@ -188,16 +194,25 @@ class WebServer404Penalty(AbstractReward):
|
||||
:type state: Dict
|
||||
"""
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# if webserver is no longer installed on the node, return 0
|
||||
if web_service_state is NOT_PRESENT_IN_STATE:
|
||||
return 0.0
|
||||
most_recent_return_code = web_service_state["last_response_status_code"]
|
||||
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
|
||||
if most_recent_return_code == 200:
|
||||
return 1.0
|
||||
elif most_recent_return_code == 404:
|
||||
return -1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
codes = web_service_state.get("response_codes_this_timestep")
|
||||
if codes:
|
||||
|
||||
def status2rew(status: int) -> int:
|
||||
"""Map status codes to reward values."""
|
||||
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
|
||||
|
||||
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
|
||||
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
|
||||
self.reward = 0.0
|
||||
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
|
||||
pass
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "WebServer404Penalty":
|
||||
@@ -217,23 +232,29 @@ class WebServer404Penalty(AbstractReward):
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
raise ValueError(msg)
|
||||
sticky = config.get("sticky", True)
|
||||
|
||||
return cls(node_hostname=node_hostname, service_name=service_name)
|
||||
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
def __init__(self, node_hostname: str) -> None:
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which has the web browser.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -243,31 +264,43 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last webpage
|
||||
had a 200 status code, because there has been an unsuccessful request since.
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
|
||||
# If the last request did actually go through, then check if the webpage also loaded
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE:
|
||||
self.reward = 0.0
|
||||
|
||||
# check if the most recent action was to request the webpage
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
"application",
|
||||
"WebBrowser",
|
||||
"execute",
|
||||
]
|
||||
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
return self.reward
|
||||
|
||||
if last_action_response.response.status != "success":
|
||||
self.reward = -1.0
|
||||
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
|
||||
_LOGGER.debug(
|
||||
"Web browser reward could not be calculated because the web browser history on node",
|
||||
f"{self._node} was not reported in the simulation state. Returning 0.0",
|
||||
)
|
||||
return 0.0 # 0 if the web browser cannot be found
|
||||
if not web_browser_state["history"]:
|
||||
return 0.0 # 0 if no requests have been attempted yet
|
||||
outcome = web_browser_state["history"][-1]["outcome"]
|
||||
if outcome == "PENDING":
|
||||
return 0.0 # 0 if a request was attempted but not yet resolved
|
||||
elif outcome == 200:
|
||||
return 1.0 # 1 for successful request
|
||||
else: # includes failure codes and SERVER_UNREACHABLE
|
||||
return -1.0 # -1 for failure
|
||||
self.reward = 0.0
|
||||
else:
|
||||
outcome = web_browser_state["history"][-1]["outcome"]
|
||||
if outcome == "PENDING":
|
||||
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
|
||||
elif outcome == 200:
|
||||
self.reward = 1.0 # 1 for successful request
|
||||
else: # includes failure codes and SERVER_UNREACHABLE
|
||||
self.reward = -1.0 # -1 for failure
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> AbstractReward:
|
||||
@@ -278,22 +311,28 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
return cls(node_hostname=node_hostname)
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
def __init__(self, node_hostname: str) -> None:
|
||||
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
"""
|
||||
Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node where the database client sits.
|
||||
:type node_hostname: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self._node: str = node_hostname
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -310,22 +349,26 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self._node,
|
||||
"application",
|
||||
"DatabaseClient",
|
||||
"execute",
|
||||
]
|
||||
|
||||
# if agent couldn't even get as far as sending the request (because for example the node was off), then
|
||||
# apply a penalty
|
||||
if self._last_request_failed:
|
||||
return -1.0
|
||||
if request_attempted: # if agent makes request, always recalculate fresh value
|
||||
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
elif not self.sticky: # if no new request and not sticky, set reward to 0
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
self.reward = 0.0
|
||||
else: # if no new request and sticky, reuse reward value from last step
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
pass
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
|
||||
return 0.0
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
|
||||
return 1.0 if last_connection_successful else -1.0
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
@@ -336,7 +379,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
return cls(node_hostname=node_hostname)
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward):
|
||||
|
||||
@@ -73,11 +73,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
@@ -135,8 +132,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
# list of connections that were established or verified during this step.
|
||||
state["last_connection_successful"] = self._last_connection_successful
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
@@ -226,13 +221,11 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
f"Using connection id {database_client_connection}"
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
else:
|
||||
self.sys_log.info(
|
||||
@@ -357,10 +350,8 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
success = self._query_success_tracker.get(query_id)
|
||||
if success:
|
||||
self.sys_log.info(f"{self.name}: Query successful {sql}")
|
||||
self._last_connection_successful = True
|
||||
return True
|
||||
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
|
||||
self._last_connection_successful = False
|
||||
return False
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -390,9 +381,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
if not self.native_connection:
|
||||
return False
|
||||
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self.native_connection.query(sql)
|
||||
@@ -416,7 +404,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
self.last_query_response = payload
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -22,7 +22,7 @@ _LOGGER = getLogger(__name__)
|
||||
class WebServer(Service):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
response_codes_this_timestep: List[HttpStatusCode] = []
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -34,11 +34,19 @@ class WebServer(Service):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
|
||||
)
|
||||
state["response_codes_this_timestep"] = [code.value for code in self.response_codes_this_timestep]
|
||||
return state
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Logic to execute at the start of the timestep - clear the observation-related attributes.
|
||||
|
||||
:param timestep: the current timestep in the episode.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.response_codes_this_timestep = []
|
||||
return super().pre_timestep(timestep)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "WebServer"
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
@@ -89,7 +97,7 @@ class WebServer(Service):
|
||||
self.send(payload=response, session_id=session_id)
|
||||
|
||||
# return true if response is OK
|
||||
self.last_response_status_code = response.status_code
|
||||
self.response_codes_this_timestep.append(response.status_code)
|
||||
return response.status_code == HttpStatusCode.OK
|
||||
|
||||
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
|
||||
|
||||
@@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction,
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
@@ -19,32 +20,30 @@ from tests.conftest import ControlledAgent
|
||||
|
||||
def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
"""Test that we get the right reward for failing to fetch a website."""
|
||||
# set up the scenario, configure the web browser to the correct url
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = WebpageUnavailablePenalty(node_hostname="client_1")
|
||||
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# client 1 has not attempted to fetch webpage yet!
|
||||
assert agent.reward_function.current_reward == 0.0
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage()
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
|
||||
# Check that before trying to fetch the webpage, the reward is 0.0
|
||||
agent.store_action(("DONOTHING", {}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.0
|
||||
|
||||
# Check that successfully fetching the webpage yields a reward of 0.7
|
||||
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.7
|
||||
|
||||
# Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7
|
||||
router: Router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP)
|
||||
assert not browser.get_webpage()
|
||||
agent.store_action(action)
|
||||
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.7
|
||||
|
||||
@@ -70,34 +69,29 @@ def test_uc2_rewards(game_and_agent):
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
|
||||
response = db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
ahi = AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="NODE_APPLICATION_EXECUTE",
|
||||
parameters={},
|
||||
request=["execute"],
|
||||
response=response,
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
)
|
||||
reward_value = comp.calculate(state, last_action_response=ahi)
|
||||
assert reward_value == 1.0
|
||||
assert ahi.reward_info == {"last_connection_successful": True}
|
||||
assert ahi.reward_info == {"connection_attempt_status": "success"}
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
|
||||
db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(state, last_action_response=ahi)
|
||||
ahi = AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
)
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=ahi,
|
||||
)
|
||||
assert reward_value == -1.0
|
||||
assert ahi.reward_info == {"last_connection_successful": False}
|
||||
assert ahi.reward_info == {"connection_attempt_status": "failure"}
|
||||
|
||||
|
||||
def test_shared_reward():
|
||||
|
||||
@@ -146,7 +146,6 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
data_manipulation_bot.port_scan_p_of_success = 1
|
||||
data_manipulation_bot.data_manipulation_p_of_success = 1
|
||||
@@ -155,4 +154,3 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
assert green_db_connection.query("SELECT") is False
|
||||
assert green_db_client.last_query_response.get("status_code") != 200
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
|
||||
|
||||
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
|
||||
"""Test to see show that the database service still operate"""
|
||||
"""Test to show that the database service still operates after corruption"""
|
||||
network: Network = ransomware_script_db_server_green_client
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
@@ -111,17 +111,18 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
green_db_client.connect()
|
||||
green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection()
|
||||
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_client_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
assert green_db_client.query("SELECT") is True
|
||||
|
||||
ransomware_script_application.attack()
|
||||
|
||||
network.apply_timestep(0)
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert green_db_client_connection.query("SELECT") is True
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
assert green_db_client.query("SELECT") is True # Still operates but now the data field of response is empty
|
||||
|
||||
299
tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
Normal file
299
tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import (
|
||||
GreenAdminDatabaseUnreachablePenalty,
|
||||
WebpageUnavailablePenalty,
|
||||
WebServer404Penalty,
|
||||
)
|
||||
from primaite.interface.request import RequestResponse
|
||||
|
||||
|
||||
class TestWebServer404PenaltySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=False)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
state = {
|
||||
"network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
|
||||
}
|
||||
last_action_response = None
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# update codes (by reference), 200 response code is now present
|
||||
codes.append(200)
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# update codes (by reference), to make it empty again, reward goes back to 0
|
||||
codes.pop()
|
||||
assert reward.calculate(state, last_action_response) == 0.0
|
||||
|
||||
# update codes (by reference), 404 response code is now present, reward = -1.0
|
||||
codes.append(404)
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# don't update codes, it still has just a 404, check the reward is -1.0 again
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=True)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
state = {
|
||||
"network": {"nodes": {"computer": {"services": {"WebService": {"response_codes_this_timestep": codes}}}}}
|
||||
}
|
||||
last_action_response = None
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# update codes (by reference), 200 response code is now present
|
||||
codes.append(200)
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# update codes (by reference), to make it empty again, reward remains at 1.0 because it's sticky
|
||||
codes.pop()
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# update codes (by reference), 404 response code is now present, reward = -1.0
|
||||
codes.append(404)
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# don't update codes, it still has just a 404, check the reward is -1.0 again
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
|
||||
class TestWebpageUnavailabilitySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebpageUnavailablePenalty("computer", sticky=False)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
)
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# agent did a successful fetch
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history.append({"outcome": 200})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
)
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
)
|
||||
assert reward.calculate(state, last_action_response) == 0.0
|
||||
|
||||
# agent fails to fetch, get a -1.0 reward
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
)
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# agent fails again to fetch, get a -1.0 reward again
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
timestep=0, action=action, parameters=params, request=request, response=response
|
||||
)
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
    """Walk a sticky ``WebpageUnavailablePenalty`` through idle, successful and failed steps.

    The key expectation is the idle step after a successful fetch: with ``sticky=True`` the
    reward is still 1.0 on that step.
    """
    penalty = WebpageUnavailablePenalty("computer", sticky=True)
    outcomes = []

    def snapshot():
        # A new state dict per step; the browser history list inside it is shared and grows.
        return {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": outcomes}}}}}}

    def idle_step():
        # History item for a step where the agent issued DO_NOTHING.
        return AgentHistoryItem(
            timestep=0,
            action="DO_NOTHING",
            parameters={},
            request=["DONOTHING"],
            response=RequestResponse(status="success", data={}),
        )

    def fetch_step(status):
        # History item for a WebBrowser execute request that finished with ``status``.
        return AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 0},
            request=["network", "node", "computer", "application", "WebBrowser", "execute"],
            response=RequestResponse(status=status, data={}),
        )

    # empty browser history and an idle agent: reward starts at 0
    assert penalty.calculate(snapshot(), idle_step()) == 0

    # a successful fetch is recorded in the history and rewarded
    outcomes.append({"outcome": 200})
    assert penalty.calculate(snapshot(), fetch_step("success")) == 1.0

    # THE IMPORTANT BIT
    # the agent idles, but a sticky reward keeps its previous value of 1.0
    assert penalty.calculate(snapshot(), idle_step()) == 1.0

    # a failed fetch (404 in the history) is penalised with -1.0
    outcomes.append({"outcome": 404})
    assert penalty.calculate(snapshot(), fetch_step("failure")) == -1.0

    # a second failed fetch is penalised with -1.0 again
    outcomes.append({"outcome": 404})
    assert penalty.calculate(snapshot(), fetch_step("failure")) == -1.0
|
||||
|
||||
|
||||
class TestGreenAdminDatabaseUnreachableSticky:
    """Exercise ``GreenAdminDatabaseUnreachablePenalty`` with ``sticky`` switched off and on.

    Every step passes an identical simulation state (an empty ``DatabaseClient`` entry), so the
    expected reward changes from step to step only with the agent's last action and its response.
    """

    @staticmethod
    def _state():
        """Return a fresh state dict with an empty ``DatabaseClient`` entry on node ``computer``."""
        return {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}

    @staticmethod
    def _idle_step():
        """Build the history item for a step in which the agent issued ``DO_NOTHING``."""
        return AgentHistoryItem(
            timestep=0,
            action="DO_NOTHING",
            parameters={},
            request=["DONOTHING"],
            response=RequestResponse(status="success", data={}),
        )

    @staticmethod
    def _db_step(status):
        """Build the history item for a ``DatabaseClient`` execute request.

        :param status: Status string placed on the ``RequestResponse`` ("success" or "failure").
        """
        return AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 0},
            request=["network", "node", "computer", "application", "DatabaseClient", "execute"],
            response=RequestResponse(status=status, data={}),
        )

    def test_non_sticky(self):
        """A non-sticky reward drops back to 0 on a step where the agent does nothing."""
        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)

        # no database request made yet, reward is 0
        assert reward.calculate(self._state(), self._idle_step()) == 0

        # agent made a successful database request
        assert reward.calculate(self._state(), self._db_step("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing, because reward is not sticky, it goes back to 0
        assert reward.calculate(self._state(), self._idle_step()) == 0.0

        # agent's database request fails, get a -1.0 reward
        assert reward.calculate(self._state(), self._db_step("failure")) == -1.0

        # agent's database request fails again, get a -1.0 reward again
        assert reward.calculate(self._state(), self._db_step("failure")) == -1.0

    def test_sticky(self):
        """A sticky reward keeps its previous value on a step where the agent does nothing."""
        reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)

        # no database request made yet, reward is 0
        assert reward.calculate(self._state(), self._idle_step()) == 0

        # agent made a successful database request
        assert reward.calculate(self._state(), self._db_step("success")) == 1.0

        # THE IMPORTANT BIT
        # agent did nothing, because reward is sticky, it stays at 1.0
        assert reward.calculate(self._state(), self._idle_step()) == 1.0

        # agent's database request fails, get a -1.0 reward
        assert reward.calculate(self._state(), self._db_step("failure")) == -1.0

        # agent's database request fails again, get a -1.0 reward again
        assert reward.calculate(self._state(), self._db_step("failure")) == -1.0
|
||||
Reference in New Issue
Block a user