Files
PrimAITE/tests/integration_tests/game_layer/test_rewards.py

122 lines
4.6 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
2024-03-13 14:01:17 +00:00
import yaml
2024-05-31 15:00:18 +01:00
from primaite.game.agent.interface import AgentHistoryItem
2024-03-03 15:52:34 +00:00
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
2024-03-13 14:01:17 +00:00
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
2024-03-03 15:52:34 +00:00
from primaite.simulator.network.hardware.nodes.host.server import Server
2024-02-08 16:15:57 +00:00
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
2024-02-06 15:05:44 +00:00
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
2024-03-03 15:52:34 +00:00
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
2024-03-13 14:01:17 +00:00
from tests import TEST_ASSETS_ROOT
2024-02-06 15:05:44 +00:00
from tests.conftest import ControlledAgent
def test_WebpageUnavailablePenalty(game_and_agent):
    """Check the reward produced by ``WebpageUnavailablePenalty`` across three situations.

    The component is registered on the controlled agent with a weight of 0.7, and the
    agent's current reward is asserted to be:

    * ``0.0`` before client_1 has attempted to fetch a webpage,
    * ``0.7`` after the browser on client_1 fetches the page successfully,
    * ``-0.7`` after a router ACL rule denies TCP traffic on the HTTP port and the fetch fails.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    weight = 0.7
    penalty = WebpageUnavailablePenalty(node_hostname="client_1")
    agent.reward_function.register_component(penalty, weight)

    def idle_step():
        """Queue a do-nothing action for the agent and advance the game by one step."""
        agent.store_action(("DONOTHING", {}))
        game.step()

    # client 1 has not attempted to fetch webpage yet!
    idle_step()
    assert agent.reward_function.current_reward == 0.0

    # Start the browser on client_1 and fetch the page while nothing blocks it.
    network = game.simulation.network
    web_browser = network.get_node_by_hostname("client_1").software_manager.software.get("WebBrowser")
    web_browser.run()
    web_browser.target_url = "http://www.example.com"
    assert web_browser.get_webpage()

    idle_step()
    assert agent.reward_function.current_reward == weight

    # Deny HTTP on the router so that the next fetch attempt fails.
    router: Router = network.get_node_by_hostname("router")
    router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP)
    assert not web_browser.get_webpage()

    idle_step()
    assert agent.reward_function.current_reward == -weight
2024-03-03 15:52:34 +00:00
def test_uc2_rewards(game_and_agent):
    """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.

    A ``DatabaseService`` is installed on server_1 and a ``DatabaseClient`` on client_1.
    While a router ACL rule permits the Postgres port, executing the client is expected to
    yield a reward of ``1.0``; once that rule is removed, the expected reward is ``-1.0``.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    # Database service on the server side.
    server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
    server_1.software_manager.install(DatabaseService)
    db_service = server_1.software_manager.software.get("DatabaseService")
    db_service.start()

    # Database client on client_1, pointed at server_1.
    client_1 = game.simulation.network.get_node_by_hostname("client_1")
    client_1.software_manager.install(DatabaseClient)
    db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
    db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
    db_client.run()

    # Permit database traffic through the router (rule is removed again further down).
    router: Router = game.simulation.network.get_node_by_hostname("router")
    router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2)

    comp = GreenAdminDatabaseUnreachablePenalty("client_1")

    def execute_and_score():
        """Execute the database client and return the component's reward for the outcome.

        The response of *this* execute request is what gets wrapped in the history item, so
        the component is never handed a response left over from an earlier request.
        """
        response = db_client.apply_request(
            [
                "execute",
            ]
        )
        state = game.get_sim_state()
        return comp.calculate(
            state,
            last_action_response=AgentHistoryItem(
                timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
            ),
        )

    # With the permit rule in place the database is reachable.
    assert execute_and_score() == 1.0

    # Without the permit rule the database can no longer be reached.
    router.acl.remove_rule(position=2)
    assert execute_and_score() == -1.0
2024-03-13 14:01:17 +00:00
def test_shared_reward():
    """Check that the defender's reward equals the sum of the two green users' rewards.

    The environment is built from ``configs/shared_rewards.yaml``. The test first asserts
    that the defender comes after both green users in the reward calculation order, then
    steps the environment 256 times with sampled actions, comparing rewards on each step.
    """
    CFG_PATH = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml"
    with open(CFG_PATH, "r") as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    env = PrimaiteGymEnv(env_config=cfg)
    env.reset()

    blue_name = "defender"
    green_names = ("client_1_green_user", "client_2_green_user")

    # The defender's reward must be calculated after the green rewards it is compared to.
    calc_order = env.game._reward_calculation_order
    for green_name in green_names:
        assert calc_order.index(blue_name) > calc_order.index(green_name)

    for _ in range(256):
        env.step(env.action_space.sample())

        agents = env.game.agents
        first_green, second_green = (agents[name].reward_function.current_reward for name in green_names)
        blue_reward = agents[blue_name].reward_function.current_reward
        assert blue_reward == first_green + second_green