#2656 - Make action penalty more configurable
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
@@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config():
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
|
||||
ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty
|
||||
CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"]
|
||||
|
||||
assert ActionPenalty_Value == CFG_Penalty_Value
|
||||
defender = env.game.agents["defender"]
|
||||
act_penalty_obj = None
|
||||
for comp in defender.reward_function.reward_components:
|
||||
if isinstance(comp[0], ActionPenalty):
|
||||
act_penalty_obj = comp[0]
|
||||
if act_penalty_obj is None:
|
||||
pytest.fail("Action penalty reward component was not added to the agent from config.")
|
||||
assert act_penalty_obj.action_penalty == -0.75
|
||||
assert act_penalty_obj.do_nothing_penalty == 0.125
|
||||
|
||||
|
||||
def test_action_penalty():
    """Test that ActionPenalty returns the configured value for the agent's last action.

    The component is built with ``action_penalty=-0.75`` and
    ``do_nothing_penalty=0.125``. ``calculate`` is called directly with an empty
    state and a hand-built ``AgentHistoryItem``, so no game or network fixture
    is needed; the two calls differ only in the recorded action.
    """
    # Create an ActionPenalty reward with distinct values for acting and for doing nothing
    penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)

    # Assert that the action penalty is applied if the action isn't DONOTHING
    reward_value = penalty.calculate(
        state={},
        last_action_response=AgentHistoryItem(
            timestep=0,
            action="NODE_APPLICATION_EXECUTE",
            parameters={"node_id": 0, "application_id": 1},
            request=["execute"],
            response=RequestResponse.from_bool(True),
        ),
    )
    assert reward_value == -0.75

    # Assert that the do-nothing value (not the action penalty) is returned for DONOTHING
    reward_value = penalty.calculate(
        state={},
        last_action_response=AgentHistoryItem(
            timestep=0,
            action="DONOTHING",
            parameters={},
            request=["do_nothing"],
            response=RequestResponse.from_bool(True),
        ),
    )
    assert reward_value == 0.125
|
||||
|
||||
|
||||
def test_action_penalty_e2e(game_and_agent):
    """Check the agent's reward after stepping the game with an ActionPenalty registered.

    For each queued action the test stores it on the agent, advances the game by
    one step and compares ``agent.reward_function.current_reward`` with the
    expected value: ``0.125`` after DONOTHING and ``-0.75`` after NODE_FILE_SCAN.
    """
    game, agent = game_and_agent
    agent: ControlledAgent

    penalty_component = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
    # NOTE(review): the second argument is presumably the component weight - confirm
    # against register_component; 1.0 leaves the expected rewards unscaled.
    agent.reward_function.register_component(penalty_component, 1.0)

    # (action to queue, reward expected once the step has been processed)
    scenarios = (
        (("DONOTHING", {}), 0.125),
        (("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0}), -0.75),
    )
    for queued_action, expected_reward in scenarios:
        agent.store_action(queued_action)
        game.step()
        assert agent.reward_function.current_reward == expected_reward
|
||||
|
||||
Reference in New Issue
Block a user