Add test for new reward

This commit is contained in:
Marek Wolan
2024-03-03 15:52:34 +00:00
parent 4d51b1a414
commit afa775baff

View File

@@ -1,7 +1,10 @@
from primaite.game.agent.rewards import WebpageUnavailablePenalty
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from tests.conftest import ControlledAgent
@@ -35,3 +38,44 @@ def test_WebpageUnavailablePenalty(game_and_agent):
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == -0.7
def test_uc2_rewards(game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
server_1.software_manager.install(DatabaseService)
db_service = server_1.software_manager.software.get("DatabaseService")
db_service.start()
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
db_client.run()
router: Router = game.simulation.network.get_node_by_hostname("router")
router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
assert reward_value == 1.0
router.acl.remove_rule(position=2)
db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
reward_value = comp.calculate(state)
assert reward_value == -1.0