diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 00374791..8ac3956c 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -255,29 +255,26 @@ class WebpageUnavailablePenalty(AbstractReward): "execute", ] - if ( - not request_attempted and self.sticky - ): # skip calculating if sticky and no new codes, reusing last step value + # skip calculating if sticky and no new codes, reusing last step value + if not request_attempted and self.sticky: return self.reward if last_action_response.response.status != "success": self.reward = -1.0 - # - elif web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: + elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", f"{self._node} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 - elif not web_browser_state["history"]: - self.reward = 0.0 # 0 if no requests have been attempted yet - outcome = web_browser_state["history"][-1]["outcome"] - if outcome == "PENDING": - self.reward = 0.0 # 0 if a request was attempted but not yet resolved - elif outcome == 200: - self.reward = 1.0 # 1 for successful request - else: # includes failure codes and SERVER_UNREACHABLE - self.reward = -1.0 # -1 for failure + else: + outcome = web_browser_state["history"][-1]["outcome"] + if outcome == "PENDING": + self.reward = 0.0 # 0 if a request was attempted but not yet resolved + elif outcome == 200: + self.reward = 1.0 # 1 for successful request + else: # includes failure codes and SERVER_UNREACHABLE + self.reward = -1.0 # -1 for failure return self.reward @@ -325,7 +322,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): db_state = access_from_nested_dict(state, self.location_in_state) # If the last request was actually sent, then check if the connection was established. - if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: + if db_state is NOT_PRESENT_IN_STATE: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") self.reward = 0.0 @@ -338,9 +335,8 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): "execute", ] - if ( - not request_attempted and self.sticky - ): # skip calculating if sticky and no new codes, reusing last step value + # skip calculating if sticky and no new codes, reusing last step value + if not request_attempted and self.sticky: return self.reward self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 933afadf..0a626c00 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -75,8 +75,6 @@ class DatabaseClient(Application, identifier="DatabaseClient"): server_password: Optional[str] = None _query_success_tracker: Dict[str, bool] = {} """Keep track of connections that were established or verified during this step. Used for rewards.""" - last_query_response: Optional[Dict] = None - """Keep track of the latest query response. Used to determine rewards.""" _server_connection_id: Optional[str] = None """Connection ID to the Database Server.""" client_connections: Dict[str, DatabaseClientConnection] = {} diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..83b04832 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -12,6 +12,7 @@ from primaite.simulator.network.hardware.nodes.network.router import ACLAction, from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -19,32 +20,30 @@ from tests.conftest import ControlledAgent def test_WebpageUnavailablePenalty(game_and_agent): """Test that we get the right reward for failing to fetch a website.""" + # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent comp = WebpageUnavailablePenalty(node_hostname="client_1") - - agent.reward_function.register_component(comp, 0.7) - action = ("DONOTHING", {}) - agent.store_action(action) - game.step() - - # client 1 has not attempted to fetch webpage yet! - assert agent.reward_function.current_reward == 0.0 - client_1 = game.simulation.network.get_node_by_hostname("client_1") - browser = client_1.software_manager.software.get("WebBrowser") + browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() browser.target_url = "http://www.example.com" - assert browser.get_webpage() - action = ("DONOTHING", {}) - agent.store_action(action) + agent.reward_function.register_component(comp, 0.7) + + # Check that before trying to fetch the webpage, the reward is 0.0 + agent.store_action(("DONOTHING", {})) + game.step() + assert agent.reward_function.current_reward == 0.0 + + # Check that successfully fetching the webpage yields a reward of 0.7 + agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == 0.7 + # Block the web traffic, check that failing to fetch the webpage yields a reward of -0.7 router: Router = game.simulation.network.get_node_by_hostname("router") router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) - assert not browser.get_webpage() - agent.store_action(action) + agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})) game.step() assert agent.reward_function.current_reward == -0.7 @@ -70,32 +69,25 @@ def test_uc2_rewards(game_and_agent): comp = GreenAdminDatabaseUnreachablePenalty("client_1") - response = db_client.apply_request( - [ - "execute", - ] - ) + request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] + response = game.simulation.apply_request(request) state = game.get_sim_state() reward_value = comp.calculate( state, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response ), ) assert reward_value == 1.0 router.acl.remove_rule(position=2) - db_client.apply_request( - [ - "execute", - ] - ) + response = game.simulation.apply_request(request) state = game.get_sim_state() reward_value = comp.calculate( state, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response ), ) assert reward_value == -1.0 diff --git a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py index a01cffbe..2e87578d 100644 --- a/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py +++ b/tests/integration_tests/system/red_applications/test_data_manipulation_bot_and_server.py @@ -146,7 +146,6 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_ assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD assert green_db_connection.query("SELECT") - assert green_db_client.last_query_response.get("status_code") == 200 data_manipulation_bot.port_scan_p_of_success = 1 data_manipulation_bot.data_manipulation_p_of_success = 1 @@ -155,4 +154,3 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_ assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED assert green_db_connection.query("SELECT") is False - assert green_db_client.last_query_response.get("status_code") != 200 diff --git a/tests/integration_tests/system/red_applications/test_ransomware_script.py b/tests/integration_tests/system/red_applications/test_ransomware_script.py index 2e3a0b1c..97abafb5 100644 --- a/tests/integration_tests/system/red_applications/test_ransomware_script.py +++ b/tests/integration_tests/system/red_applications/test_ransomware_script.py @@ -103,7 +103,7 @@ def test_ransomware_script_attack(ransomware_script_and_db_server): def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_green_client): - """Test to see show that the database service still operate""" + """Test to show that the database service still operates after corruption""" network: Network = ransomware_script_db_server_green_client client_1: Computer = network.get_node_by_hostname("client_1") @@ -111,17 +111,18 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_ client_2: Computer = network.get_node_by_hostname("client_2") green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient") + green_db_client.connect() green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection() server: Server = network.get_node_by_hostname("server_1") db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService") assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD - assert green_db_client_connection.query("SELECT") - assert green_db_client.last_query_response.get("status_code") == 200 + assert green_db_client.query("SELECT") is True ransomware_script_application.attack() + network.apply_timestep(0) + assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT - assert green_db_client_connection.query("SELECT") is True - assert green_db_client.last_query_response.get("status_code") == 200 + assert green_db_client.query("SELECT") is True # Still operates but now the data field of response is empty