#2656 - Make action penalty more configurable

This commit is contained in:
Marek Wolan
2024-06-27 12:01:32 +01:00
parent e204afff6f
commit 7a680678aa
3 changed files with 62 additions and 187 deletions

View File

@@ -363,33 +363,33 @@ class SharedReward(AbstractReward):
class ActionPenalty(AbstractReward):
"""Apply a configurable reward or penalty depending on whether the agent takes an action or chooses DONOTHING."""
def __init__(self, agent_name: str, penalty: float):
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
This negative reward should be applied when the agent in training chooses to take any
action that isn't DONOTHING.
Reward or penalise agents for doing nothing or taking actions.
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:type action_penalty: float
:param do_nothing_penalty: Reward to give agents for taking the DONOTHING action
:type do_nothing_penalty: float
"""
self.agent_name = agent_name
self.penalty = penalty
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
if last_action_response.action == "DONOTHING":
# No penalty for doing nothing at present
return 0
return self.do_nothing_penalty
else:
_LOGGER.info(
f"Blue Agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}"
)
return self.penalty
return self.action_penalty
@classmethod
def from_config(cls, config: Dict) -> "ActionPenalty":
"""Build the ActionPenalty object from config."""
agent_name = config.get("agent_name")
penalty_value = config.get("penalty_value", 0) # default to 0.
return cls(agent_name=agent_name, penalty=penalty_value)
action_penalty = config.get("action_penalty", -1.0)
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
class RewardFunction:

View File

@@ -21,135 +21,6 @@ game:
low: 0
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_2
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_2
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: WEBPAGE_UNAVAILABLE_PENALTY
weight: 0.25
options:
node_hostname: client_1
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 0.05
options:
node_hostname: client_1
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
- node_name: client_2
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -712,19 +583,11 @@ agents:
reward_function:
reward_components:
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_1_green_user
- type: SHARED_REWARD
weight: 1.0
options:
agent_name: client_2_green_user
- type: ACTION_PENALTY
weight: 1.0
options:
agent_name: defender
penalty_value: -1
action_penalty: -0.75
do_nothing_penalty: 0.125
agent_settings:

View File

@@ -1,9 +1,11 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
import yaml
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.interface.request import RequestResponse
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
@@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config():
env = PrimaiteGymEnv(env_config=cfg)
env.reset()
ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty
CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"]
assert ActionPenalty_Value == CFG_Penalty_Value
defender = env.game.agents["defender"]
act_penalty_obj = None
for comp in defender.reward_function.reward_components:
if isinstance(comp[0], ActionPenalty):
act_penalty_obj = comp[0]
if act_penalty_obj is None:
pytest.fail("Action penalty reward component was not added to the agent from config.")
assert act_penalty_obj.action_penalty == -0.75
assert act_penalty_obj.do_nothing_penalty == 0.125
def test_action_penalty(game_and_agent):
def test_action_penalty():
"""Test that the action penalty is correctly applied when agent performs any action"""
# Create an ActionPenalty Reward
Penalty = ActionPenalty(agent_name="Test_Blue_Agent", penalty=-1.0)
game, _ = game_and_agent
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
server_1.software_manager.install(DatabaseService)
db_service = server_1.software_manager.software.get("DatabaseService")
db_service.start()
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DatabaseClient)
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
db_client.run()
response = db_client.apply_request(
[
"execute",
]
)
state = game.get_sim_state()
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
# Assert that penalty is applied if action isn't DONOTHING
reward_value = Penalty.calculate(
state,
state={},
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={"node_id": 0, "application_id": 1},
request=["execute"],
response=RequestResponse.from_bool(True),
),
)
assert reward_value == -1.0
assert reward_value == -0.75
# Assert that the configured do-nothing reward is applied for a DONOTHING action
reward_value = Penalty.calculate(
state,
state={},
last_action_response=AgentHistoryItem(
timestep=0, action="DONOTHING", parameters={}, request=["execute"], response=response
timestep=0,
action="DONOTHING",
parameters={},
request=["do_nothing"],
response=RequestResponse.from_bool(True),
),
)
assert reward_value == 0
assert reward_value == 0.125
def test_action_penalty_e2e(game_and_agent):
"""Test that the action penalty component yields the right reward end-to-end for DONOTHING and non-DONOTHING actions."""
game, agent = game_and_agent
agent: ControlledAgent
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
agent.reward_function.register_component(comp, 1.0)
action = ("DONOTHING", {})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == 0.125
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == -0.75