From afa775baff03a7754ff0818934563f9587a2a1cb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 15:52:34 +0000 Subject: [PATCH] Add test for new reward --- .../game_layer/test_rewards.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index fd8a89a4..53753967 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,7 +1,10 @@ -from primaite.game.agent.rewards import WebpageUnavailablePenalty +from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService from tests.conftest import ControlledAgent @@ -35,3 +38,44 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent.store_action(action) game.step() assert agent.reward_function.current_reward == -0.7 + + +def test_uc2_rewards(game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + server_1.software_manager.install(DatabaseService) + db_service = server_1.software_manager.software.get("DatabaseService") + db_service.start() + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) + db_client.run() + + router: Router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + + comp = GreenAdminDatabaseUnreachablePenalty("client_1") + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == 1.0 + + router.acl.remove_rule(position=2) + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == -1.0