#2656 - Make action penalty more configurable

This commit is contained in:
Marek Wolan
2024-06-27 12:01:32 +01:00
parent e204afff6f
commit 7a680678aa
3 changed files with 62 additions and 187 deletions

View File

@@ -1,9 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
import yaml
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
@@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config():
env = PrimaiteGymEnv(env_config=cfg)
env.reset()
ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty
CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"]
assert ActionPenalty_Value == CFG_Penalty_Value
defender = env.game.agents["defender"]
act_penalty_obj = None
for comp in defender.reward_function.reward_components:
if isinstance(comp[0], ActionPenalty):
act_penalty_obj = comp[0]
if act_penalty_obj is None:
pytest.fail("Action penalty reward component was not added to the agent from config.")
assert act_penalty_obj.action_penalty == -0.75
assert act_penalty_obj.do_nothing_penalty == 0.125
def test_action_penalty():
    """Check that ActionPenalty returns the configured value for each kind of action.

    The component is built with distinct ``action_penalty`` and
    ``do_nothing_penalty`` values so the two branches cannot be confused:
    any action other than DONOTHING must yield ``action_penalty`` and a
    DONOTHING action must yield ``do_nothing_penalty``.
    """
    # No simulation is needed: the component is called with an empty state and
    # hand-built history items, so only the action name drives the result.
    penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)

    # A non-DONOTHING action must be charged the action penalty.
    reward_value = penalty.calculate(
        state={},
        last_action_response=AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 1},
            request=["execute"],
            response=RequestResponse.from_bool(True),
        ),
    )
    assert reward_value == -0.75

    # A DONOTHING action must receive the (separately configured) do-nothing value.
    reward_value = penalty.calculate(
        state={},
        last_action_response=AgentHistoryItem(
            timestep=0,
            action="DONOTHING",
            parameters={},
            request=["do_nothing"],
            response=RequestResponse.from_bool(True),
        ),
    )
    assert reward_value == 0.125
def test_action_penalty_e2e(game_and_agent):
    """Check the ActionPenalty reward end to end by stepping the game with real actions.

    An ActionPenalty component is registered on the fixture agent, then the
    game is stepped once with DONOTHING and once with NODE_FILE_SCAN. After
    each step the agent's current reward must equal the matching configured
    value.

    :param game_and_agent: Fixture that unpacks to a ``(game, agent)`` pair.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
    # Second argument is presumably the component weight — TODO confirm.
    agent.reward_function.register_component(comp, 1.0)

    # Doing nothing should produce the do-nothing value.
    action = ("DONOTHING", {})
    agent.store_action(action)
    game.step()
    assert agent.reward_function.current_reward == 0.125

    # Any other action (here a file scan) should produce the action penalty.
    action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
    agent.store_action(action)
    game.step()
    assert agent.reward_function.current_reward == -0.75