#2656 - Make action penalty more configurable
This commit is contained in:
@@ -363,33 +363,33 @@ class SharedReward(AbstractReward):
|
||||
class ActionPenalty(AbstractReward):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
def __init__(self, agent_name: str, penalty: float):
|
||||
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
"""
|
||||
Initialise the reward.
|
||||
|
||||
This negative reward should be applied when the agent in training chooses to take any
|
||||
action that isn't DONOTHING.
|
||||
Reward or penalise agents for doing nothing or taking actions.
|
||||
|
||||
:param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
:type action_penalty: float
|
||||
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
:type do_nothing_penalty: float
|
||||
"""
|
||||
self.agent_name = agent_name
|
||||
self.penalty = penalty
|
||||
self.action_penalty = action_penalty
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied."""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
# No penalty for doing nothing at present
|
||||
return 0
|
||||
return self.do_nothing_penalty
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"Blue Agent has incurred a penalty of {self.penalty}, for action: {last_action_response.action}"
|
||||
)
|
||||
return self.penalty
|
||||
return self.action_penalty
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "ActionPenalty":
|
||||
"""Build the ActionPenalty object from config."""
|
||||
agent_name = config.get("agent_name")
|
||||
penalty_value = config.get("penalty_value", 0) # default to 0.
|
||||
return cls(agent_name=agent_name, penalty=penalty_value)
|
||||
action_penalty = config.get("action_penalty", -1.0)
|
||||
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
|
||||
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
|
||||
@@ -21,135 +21,6 @@ game:
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_2
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
- application_name: DatabaseClient
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_applications_per_node: 2
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
2:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: WEBPAGE_UNAVAILABLE_PENALTY
|
||||
weight: 0.25
|
||||
options:
|
||||
node_hostname: client_1
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 0.05
|
||||
options:
|
||||
node_hostname: client_1
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
- node_name: client_2
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
@@ -712,19 +583,11 @@ agents:
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_1_green_user
|
||||
- type: SHARED_REWARD
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: client_2_green_user
|
||||
- type: ACTION_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
agent_name: defender
|
||||
penalty_value: -1
|
||||
action_penalty: -0.75
|
||||
do_nothing_penalty: 0.125
|
||||
|
||||
|
||||
agent_settings:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
@@ -130,56 +132,66 @@ def test_action_penalty_loads_from_config():
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
|
||||
ActionPenalty_Value = env.game.agents["defender"].reward_function.reward_components[2][0].penalty
|
||||
CFG_Penalty_Value = cfg["agents"][3]["reward_function"]["reward_components"][2]["options"]["penalty_value"]
|
||||
|
||||
assert ActionPenalty_Value == CFG_Penalty_Value
|
||||
defender = env.game.agents["defender"]
|
||||
act_penalty_obj = None
|
||||
for comp in defender.reward_function.reward_components:
|
||||
if isinstance(comp[0], ActionPenalty):
|
||||
act_penalty_obj = comp[0]
|
||||
if act_penalty_obj is None:
|
||||
pytest.fail("Action penalty reward component was not added to the agent from config.")
|
||||
assert act_penalty_obj.action_penalty == -0.75
|
||||
assert act_penalty_obj.do_nothing_penalty == 0.125
|
||||
|
||||
|
||||
def test_action_penalty(game_and_agent):
|
||||
def test_action_penalty():
|
||||
"""Test that the action penalty is correctly applied when agent performs any action"""
|
||||
|
||||
# Create an ActionPenalty Reward
|
||||
Penalty = ActionPenalty(agent_name="Test_Blue_Agent", penalty=-1.0)
|
||||
|
||||
game, _ = game_and_agent
|
||||
|
||||
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
|
||||
server_1.software_manager.install(DatabaseService)
|
||||
db_service = server_1.software_manager.software.get("DatabaseService")
|
||||
db_service.start()
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
db_client.configure(server_ip_address=server_1.network_interface[1].ip_address)
|
||||
db_client.run()
|
||||
|
||||
response = db_client.apply_request(
|
||||
[
|
||||
"execute",
|
||||
]
|
||||
)
|
||||
|
||||
state = game.get_sim_state()
|
||||
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
# Assert that penalty is applied if action isn't DONOTHING
|
||||
reward_value = Penalty.calculate(
|
||||
state,
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
timestep=0,
|
||||
action="NODE_APPLICATION_EXECUTE",
|
||||
parameters={"node_id": 0, "application_id": 1},
|
||||
request=["execute"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == -1.0
|
||||
assert reward_value == -0.75
|
||||
|
||||
# Assert that no penalty applied for a DONOTHING action
|
||||
reward_value = Penalty.calculate(
|
||||
state,
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="DONOTHING", parameters={}, request=["execute"], response=response
|
||||
timestep=0,
|
||||
action="DONOTHING",
|
||||
parameters={},
|
||||
request=["do_nothing"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == 0
|
||||
assert reward_value == 0.125
|
||||
|
||||
|
||||
def test_action_penalty_e2e(game_and_agent):
|
||||
"""Test that we get the right reward for doing actions to fetch a website."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
agent.reward_function.register_component(comp, 1.0)
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.125
|
||||
|
||||
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.75
|
||||
|
||||
Reference in New Issue
Block a user