#2736 - Fix up broken reward tests
This commit is contained in:
@@ -255,29 +255,26 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
"execute",
|
||||
]
|
||||
|
||||
if (
|
||||
not request_attempted and self.sticky
|
||||
): # skip calculating if sticky and no new codes, reusing last step value
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
return self.reward
|
||||
|
||||
if last_action_response.response.status != "success":
|
||||
self.reward = -1.0
|
||||
#
|
||||
elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
|
||||
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
|
||||
_LOGGER.debug(
|
||||
"Web browser reward could not be calculated because the web browser history on node",
|
||||
f"{self._node} was not reported in the simulation state. Returning 0.0",
|
||||
)
|
||||
self.reward = 0.0
|
||||
elif not web_browser_state["history"]:
|
||||
self.reward = 0.0 # 0 if no requests have been attempted yet
|
||||
outcome = web_browser_state["history"][-1]["outcome"]
|
||||
if outcome == "PENDING":
|
||||
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
|
||||
elif outcome == 200:
|
||||
self.reward = 1.0 # 1 for successful request
|
||||
else: # includes failure codes and SERVER_UNREACHABLE
|
||||
self.reward = -1.0 # -1 for failure
|
||||
else:
|
||||
outcome = web_browser_state["history"][-1]["outcome"]
|
||||
if outcome == "PENDING":
|
||||
self.reward = 0.0 # 0 if a request was attempted but not yet resolved
|
||||
elif outcome == 200:
|
||||
self.reward = 1.0 # 1 for successful request
|
||||
else: # includes failure codes and SERVER_UNREACHABLE
|
||||
self.reward = -1.0 # -1 for failure
|
||||
|
||||
return self.reward
|
||||
|
||||
@@ -325,7 +322,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
if db_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
self.reward = 0.0
|
||||
|
||||
@@ -338,9 +335,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
"execute",
|
||||
]
|
||||
|
||||
if (
|
||||
not request_attempted and self.sticky
|
||||
): # skip calculating if sticky and no new codes, reusing last step value
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
return self.reward
|
||||
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
|
||||
@@ -75,8 +75,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
server_password: Optional[str] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
|
||||
@@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction,
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
@@ -19,32 +20,30 @@ from tests.conftest import ControlledAgent
|
||||
|
||||
def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
"""Test that we get the right reward for failing to fetch a website."""
|
||||
# set up the scenario, configure the web browser to the correct url
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = WebpageUnavailablePenalty(node_hostname="client_1")
|
||||
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# client 1 has not attempted to fetch webpage yet!
|
||||
assert agent.reward_function.current_reward == 0.0
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage()
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
|
||||
# Check that before trying to fetch the webpage, the reward is 0.0
|
||||
agent.store_action(("DONOTHING", {}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.0
|
||||
|
||||
# Check that successfully fetching the webpage yields a reward of 0.7
|
||||
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.7
|
||||
|
||||
# Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7
|
||||
router: Router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP)
|
||||
assert not browser.get_webpage()
|
||||
agent.store_action(action)
|
||||
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.7
|
||||
|
||||
@@ -70,32 +69,25 @@ def test_uc2_rewards(game_and_agent):
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
|
||||
|
||||
response = db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == 1.0
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
|
||||
db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
response = game.simulation.apply_request(request)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
|
||||
),
|
||||
)
|
||||
assert reward_value == -1.0
|
||||
|
||||
@@ -146,7 +146,6 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
data_manipulation_bot.port_scan_p_of_success = 1
|
||||
data_manipulation_bot.data_manipulation_p_of_success = 1
|
||||
@@ -155,4 +154,3 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
assert green_db_connection.query("SELECT") is False
|
||||
assert green_db_client.last_query_response.get("status_code") != 200
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_ransomware_script_attack(ransomware_script_and_db_server):
|
||||
|
||||
|
||||
def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client):
|
||||
"""Test to see show that the database service still operate"""
|
||||
"""Test to show that the database service still operates after corruption"""
|
||||
network: Network = ransomware_script_db_server_green_client
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
@@ -111,17 +111,18 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
green_db_client.connect()
|
||||
green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection()
|
||||
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_client_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
assert green_db_client.query("SELECT") is True
|
||||
|
||||
ransomware_script_application.attack()
|
||||
|
||||
network.apply_timestep(0)
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert green_db_client_connection.query("SELECT") is True
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
assert green_db_client.query("SELECT") is True # Still operates but now the data field of response is empty
|
||||
|
||||
Reference in New Issue
Block a user