From 7a680678aa4e69355f1c2a11bf2c8157f2bae321 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 27 Jun 2024 12:01:32 +0100 Subject: [PATCH] #2656 - Make action penalty more configurable --- src/primaite/game/agent/rewards.py | 28 ++-- tests/assets/configs/action_penalty.yaml | 141 +----------------- .../game_layer/test_rewards.py | 80 +++++----- 3 files changed, 62 insertions(+), 187 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index a0736bb0..4a17e9a5 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -363,33 +363,33 @@ class SharedReward(AbstractReward): class ActionPenalty(AbstractReward): """Apply a negative reward when taking any action except DONOTHING.""" - def __init__(self, agent_name: str, penalty: float): + def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: """ Initialise the reward. - This negative reward should be applied when the agent in training chooses to take any - action that isn't DONOTHING. + Reward or penalise agents for doing nothing or taking actions. 
+ + :param action_penalty: Reward to give agents for taking any action except DONOTHING + :type action_penalty: float + :param do_nothing_penalty: Reward to give agents for taking the DONOTHING action + :type do_nothing_penalty: float """ - self.agent_name = agent_name - self.penalty = penalty + self.action_penalty = action_penalty + self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied.""" if last_action_response.action == "DONOTHING": - # No penalty for doing nothing at present - return 0 + return self.do_nothing_penalty else: - _LOGGER.info( - f"Blue Agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}" - ) - return self.penalty + return self.action_penalty @classmethod def from_config(cls, config: Dict) -> "ActionPenalty": """Build the ActionPenalty object from config.""" - agent_name = config.get("agent_name") - penalty_value = config.get("penalty_value", 0) # default to 0. 
- return cls(agent_name=agent_name, penalty=penalty_value) + action_penalty = config.get("action_penalty", -1.0) + do_nothing_penalty = config.get("do_nothing_penalty", 0.0) + return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) class RewardFunction: diff --git a/tests/assets/configs/action_penalty.yaml b/tests/assets/configs/action_penalty.yaml index 4eb562fe..1771ba5f 100644 --- a/tests/assets/configs/action_penalty.yaml +++ b/tests/assets/configs/action_penalty.yaml @@ -21,135 +21,6 @@ game: low: 0 agents: - - ref: client_2_green_user - team: GREEN - type: ProbabilisticAgent - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 - observation_space: null - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_2 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 - action_map: - 0: - action: DONOTHING - options: {} - 1: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 0 - 2: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 1 - - reward_function: - reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 - options: - node_hostname: client_2 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_2 - - - ref: client_1_green_user - team: GREEN - type: ProbabilisticAgent - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 - observation_space: null - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: WebBrowser - - application_name: DatabaseClient - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - max_applications_per_node: 2 - action_map: - 
0: - action: DONOTHING - options: {} - 1: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 0 - 2: - action: NODE_APPLICATION_EXECUTE - options: - node_id: 0 - application_id: 1 - - reward_function: - reward_components: - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.25 - options: - node_hostname: client_1 - - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY - weight: 0.05 - options: - node_hostname: client_1 - - - ref: data_manipulation_attacker - team: RED - type: RedDatabaseCorruptingAgent - - observation_space: null - - action_space: - action_list: - - type: DONOTHING - - type: NODE_APPLICATION_EXECUTE - options: - nodes: - - node_name: client_1 - applications: - - application_name: DataManipulationBot - - node_name: client_2 - applications: - - application_name: DataManipulationBot - max_folders_per_node: 1 - max_files_per_folder: 1 - max_services_per_node: 1 - - reward_function: - reward_components: - - type: DUMMY - - agent_settings: # options specific to this particular agent type, basically args of __init__(self) - start_settings: - start_step: 25 - frequency: 20 - variance: 5 - ref: defender team: BLUE @@ -712,19 +583,11 @@ agents: reward_function: reward_components: - - type: SHARED_REWARD - weight: 1.0 - options: - agent_name: client_1_green_user - - type: SHARED_REWARD - weight: 1.0 - options: - agent_name: client_2_green_user - type: ACTION_PENALTY weight: 1.0 options: - agent_name: defender - penalty_value: -1 + action_penalty: -0.75 + do_nothing_penalty: 0.125 agent_settings: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 95e70271..2bf551c8 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,9 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest import yaml from primaite.game.agent.interface import AgentHistoryItem from 
primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router @@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config(): env = PrimaiteGymEnv(env_config=cfg) env.reset() - - ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty - CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"] - - assert ActionPenalty_Value == CFG_Penalty_Value + defender = env.game.agents["defender"] + act_penalty_obj = None + for comp in defender.reward_function.reward_components: + if isinstance(comp[0], ActionPenalty): + act_penalty_obj = comp[0] + if act_penalty_obj is None: + pytest.fail("Action penalty reward component was not added to the agent from config.") + assert act_penalty_obj.action_penalty == -0.75 + assert act_penalty_obj.do_nothing_penalty == 0.125 -def test_action_penalty(game_and_agent): +def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(agent_name="Test_Blue_Agent", penalty=-1.0) - - game, _ = game_and_agent - - server_1: Server = game.simulation.network.get_node_by_hostname("server_1") - server_1.software_manager.install(DatabaseService) - db_service = server_1.software_manager.software.get("DatabaseService") - db_service.start() - - client_1 = game.simulation.network.get_node_by_hostname("client_1") - client_1.software_manager.install(DatabaseClient) - db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") - 
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) - db_client.run() - - response = db_client.apply_request( - [ - "execute", - ] - ) - - state = game.get_sim_state() + Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate( - state, + state={}, last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={"node_id": 0, "application_id": 1}, + request=["execute"], + response=RequestResponse.from_bool(True), ), ) - assert reward_value == -1.0 + assert reward_value == -0.75 # Assert that no penalty applied for a DONOTHING action reward_value = Penalty.calculate( - state, + state={}, last_action_response=AgentHistoryItem( - timestep=0, action="DONOTHING", parameters={}, request=["execute"], response=response + timestep=0, + action="DONOTHING", + parameters={}, + request=["do_nothing"], + response=RequestResponse.from_bool(True), ), ) - assert reward_value == 0 + assert reward_value == 0.125 + + +def test_action_penalty_e2e(game_and_agent): + """Test that stepping the game gives the configured reward for DONOTHING and for other actions.""" + game, agent = game_and_agent + agent: ControlledAgent + comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + + agent.reward_function.register_component(comp, 1.0) + + action = ("DONOTHING", {}) + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == 0.125 + + action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0}) + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == -0.75