Files

204 lines
8.3 KiB
Python

# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
import pytest
import yaml
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
    """Check the sticky WebpageUnavailablePenalty before, during and after web traffic is blocked.

    The component is registered with a weight of 0.7, so the agent's reward is expected to be
    0.0 before any fetch, 0.7 after a successful fetch and -0.7 after a fetch that fails
    because the router denies HTTP traffic.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    def step_with(action_name: str, params: dict) -> float:
        """Queue one action for the agent, advance the game a step and return the resulting reward."""
        agent.store_action((action_name, params))
        game.step()
        return agent.reward_function.current_reward

    # Start the browser on client_1 and point it at the page to fetch.
    client_1 = game.simulation.network.get_node_by_hostname("client_1")
    browser: WebBrowser = client_1.software_manager.software.get("web-browser")
    browser.run()
    browser.config.target_url = "http://www.example.com"

    # Attach the reward component under test with a weight of 0.7.
    penalty_config = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
    agent.reward_function.register_component(WebpageUnavailablePenalty(config=penalty_config), 0.7)

    # No fetch has been attempted yet, so the reward is neutral.
    assert step_with("do-nothing", {}) == 0.0

    # A successful fetch yields the full positive weight.
    assert step_with("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"}) == 0.7

    # Deny HTTP at the router so that the next fetch fails.
    router: Router = game.simulation.network.get_node_by_hostname("router")
    router.acl.add_rule(
        action=ACLAction.DENY,
        protocol=PROTOCOL_LOOKUP["TCP"],
        src_port=PORT_LOOKUP["HTTP"],
        dst_port=PORT_LOOKUP["HTTP"],
    )

    # A failed fetch yields the full negative weight.
    assert step_with("node-application-execute", {"node_name": "client_1", "application_name": "web-browser"}) == -0.7
def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
    """Check GreenAdminDatabaseUnreachablePenalty for a permitted and then a blocked database connection.

    With an ACL rule permitting the database port, executing the database client on client_1 is
    expected to give a reward of 1.0 and a "success" status; once that rule is removed, the
    same request is expected to give -1.0 and a "failure" status.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    # Database service on server_1.
    server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
    server_1.software_manager.install(DatabaseService)
    server_1.software_manager.software.get("database-service").start()

    # Database client on client_1, pointed at server_1.
    client_1 = game.simulation.network.get_node_by_hostname("client_1")
    client_1.software_manager.install(DatabaseClient)
    db_client: DatabaseClient = client_1.software_manager.software.get("database-client")
    db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
    db_client.run()

    # Permit database traffic through the router at ACL position 2.
    router: Router = game.simulation.network.get_node_by_hostname("router")
    router.acl.add_rule(
        ACLAction.PERMIT,
        src_port=PORT_LOOKUP["POSTGRES_SERVER"],
        dst_port=PORT_LOOKUP["POSTGRES_SERVER"],
        position=2,
    )

    penalty_config = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
    comp = GreenAdminDatabaseUnreachablePenalty(config=penalty_config)
    request = ["network", "node", "client_1", "application", "database-client", "execute"]

    def attempt_connection() -> tuple[float, AgentHistoryItem]:
        """Apply the execute request and return the component's reward with the history item it was given."""
        response = game.simulation.apply_request(request)
        state = game.get_sim_state()
        history_item = AgentHistoryItem(
            timestep=0, action="node-application-execute", parameters={}, request=request, response=response
        )
        return comp.calculate(state, last_action_response=history_item), history_item

    # While the permit rule is in place the connection attempt is rewarded.
    reward_value, ahi = attempt_connection()
    assert reward_value == 1.0
    assert ahi.reward_info == {"connection_attempt_status": "success"}

    # Without the permit rule the same attempt is penalised.
    router.acl.remove_rule(position=2)
    reward_value, ahi = attempt_connection()
    assert reward_value == -1.0
    assert ahi.reward_info == {"connection_attempt_status": "failure"}
def test_shared_reward():
    """Check that the defender's reward equals the sum of both green users' rewards.

    Also checks that the defender appears after both green users in the game's reward
    calculation order. The reward equality is asserted after each of 256 randomly sampled steps.
    """
    config_path = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml"
    with open(config_path, "r") as config_file:
        cfg = yaml.safe_load(config_file)

    env = PrimaiteGymEnv(env_config=cfg)
    env.reset()

    # The defender must come later in the order than both green users.
    calc_order = env.game._reward_calculation_order
    defender_position = calc_order.index("defender")
    assert defender_position > calc_order.index("client_1_green_user")
    assert defender_position > calc_order.index("client_2_green_user")

    def reward_of(agent_name: str):
        """Return the current reward of the named agent."""
        return env.game.agents[agent_name].reward_function.current_reward

    for _ in range(256):
        env.step(env.action_space.sample())
        g1_reward = reward_of("client_1_green_user")
        g2_reward = reward_of("client_2_green_user")
        blue_reward = reward_of("defender")
        assert blue_reward == g1_reward + g2_reward
def test_action_penalty_loads_from_config():
    """Check that the ActionPenalty component and its values are loaded from config into PrimaiteGymEnv."""
    config_path = TEST_ASSETS_ROOT / "configs/action_penalty.yaml"
    with open(config_path, "r") as config_file:
        cfg = yaml.safe_load(config_file)

    env = PrimaiteGymEnv(env_config=cfg)
    env.reset()

    # Each registered entry holds the component object at index 0; collect the ActionPenalty ones.
    registered = env.game.agents["defender"].reward_function.reward_components
    action_penalties = [entry[0] for entry in registered if isinstance(entry[0], ActionPenalty)]
    if not action_penalties:
        pytest.fail("Action penalty reward component was not added to the agent from config.")

    # If several are registered, inspect the last one.
    act_penalty_obj = action_penalties[-1]
    assert act_penalty_obj.config.action_penalty == -0.75
    assert act_penalty_obj.config.do_nothing_penalty == 0.125
def test_action_penalty():
    """Check the values ActionPenalty.calculate returns for a do-nothing and a non-do-nothing action."""
    penalty = ActionPenalty(config=ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125))

    # An action other than do-nothing should yield the configured action_penalty.
    executed = AgentHistoryItem(
        timestep=0,
        action="node-application-execute",
        parameters={"node_name": "client", "application_name": "web-browser"},
        request=["execute"],
        response=RequestResponse.from_bool(True),
    )
    assert penalty.calculate(state={}, last_action_response=executed) == -0.75

    # A do-nothing action should yield the configured do_nothing_penalty instead.
    idle = AgentHistoryItem(
        timestep=0,
        action="do-nothing",
        parameters={},
        request=["do-nothing"],
        response=RequestResponse.from_bool(True),
    )
    assert penalty.calculate(state={}, last_action_response=idle) == 0.125
def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
    """Check ActionPenalty rewards through full game steps.

    With the component registered at weight 1.0, a do-nothing step is expected to give a reward
    of 0.125 and a node-file-scan step a reward of -0.75.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    penalty_config = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
    agent.reward_function.register_component(ActionPenalty(config=penalty_config), 1.0)

    # Doing nothing yields the do_nothing_penalty value.
    agent.store_action(("do-nothing", {}))
    game.step()
    assert agent.reward_function.current_reward == 0.125

    # Any other action, here a file scan, yields the action_penalty value.
    scan_params = {"node_name": "client", "folder_name": "downloads", "file_name": "document.pdf"}
    agent.store_action(("node-file-scan", scan_params))
    game.step()
    assert agent.reward_function.current_reward == -0.75