Files
PrimAITE/src/primaite/game/agent/rewards.py
2024-07-16 09:32:26 +01:00

459 lines
18 KiB
Python

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""
Manages the reward function for the agent.
Each agent is equipped with a RewardFunction, which is made up of a list of reward components. The components are
designed to calculate a reward value based on the current state of the simulation. The overall reward function is a
weighed sum of the components.
The reward function is typically specified using a config yaml file or a config dictionary. The following example shows
the structure:
```yaml
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_name: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_name: web_server
service_ref: web_server_database_client
```
"""
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing_extensions import Never
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@classmethod
@abstractmethod
def from_config(cls, config: dict) -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: dict
:return: The reward component.
:rtype: AbstractReward
"""
return cls()
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@classmethod
def from_config(cls, config: dict) -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:return: The reward component.
:rtype: DummyReward
"""
return cls()
class DatabaseFileIntegrity(AbstractReward):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
:type file_name: str
"""
self.location_in_state = [
"network",
"nodes",
node_hostname,
"file_system",
"folders",
folder_name,
"files",
file_name,
]
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.debug(
f"Could not calculate {self.__class__} reward because "
"simulation state did not contain enough information."
)
return 0.0
health_status = database_file_state["health_status"]
if health_status == 2:
return -1
elif health_status == 1:
return 1
else:
return 0
@classmethod
def from_config(cls, config: Dict) -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not (node_hostname and folder_name and file_name):
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str) -> None:
"""Initialise the reward component.
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
web_service_state = access_from_nested_dict(state, self.location_in_state)
if web_service_state is NOT_PRESENT_IN_STATE:
return 0.0
most_recent_return_code = web_service_state["last_response_status_code"]
# TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
if most_recent_return_code == 200:
return 1.0
elif most_recent_return_code == 404:
return -1.0
else:
return 0.0
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_name and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
raise ValueError(msg)
return cls(node_hostname=node_hostname, service_name=service_name)
class WebpageUnavailablePenalty(AbstractReward):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
When the green agent requests to execute the browser application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last webpage
had a 200 status code, because there has been an unsuccessful request since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request did actually go through, then check if the webpage also loaded
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
)
return 0.0 # 0 if the web browser cannot be found
if not web_browser_state["history"]:
return 0.0 # 0 if no requests have been attempted yet
outcome = web_browser_state["history"][-1]["outcome"]
if outcome == "PENDING":
return 0.0 # 0 if a request was attempted but not yet resolved
elif outcome == 200:
return 1.0 # 1 for successful request
else: # includes failure codes and SERVER_UNREACHABLE
return -1.0 # -1 for failure
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str) -> None:
"""
Initialise the reward component.
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
When the green agent requests to execute the database client application, and that request fails, this reward
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
# if agent couldn't even get as far as sending the request (because for example the node was off), then
# apply a penalty
if self._last_request_failed:
return -1.0
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
return cls(node_hostname=node_hostname)
class SharedReward(AbstractReward):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
"""
Initialise the shared reward.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":
"""
Build the SharedReward object from config.
:param config: Configuration dictionary
:type config: Dict
"""
agent_name = config.get("agent_name")
return cls(agent_name=agent_name)
class ActionPenalty(AbstractReward):
"""Apply a negative reward when taking any action except DONOTHING."""
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
Reward or penalise agents for doing nothing or taking actions.
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
:type do_nothing_penalty: float
"""
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:
return self.action_penalty
@classmethod
def from_config(cls, config: Dict) -> "ActionPenalty":
"""Build the ActionPenalty object from config."""
action_penalty = config.get("action_penalty", -1.0)
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
class RewardFunction:
"""Manages the reward function for the agent."""
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
"ACTION_PENALTY": ActionPenalty,
}
"""List of reward class identifiers."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float = 0.0
self.total_reward: float = 0.0
def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
:param component: Instance of a reward component.
:type component: AbstractReward
:param weight: Relative weight of the reward component, defaults to 1.0
:type weight: float, optional
"""
self.reward_components.append((component, weight))
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
:type state: Dict
"""
total = 0.0
for comp_and_weight in self.reward_components:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state, last_action_response=last_action_response)
self.current_reward = total
return self.current_reward
@classmethod
def from_config(cls, config: Dict) -> "RewardFunction":
"""Create a reward function from a config dictionary.
:param config: dict of options for the reward manager's constructor
:type config: Dict
:return: The reward manager.
:rtype: RewardFunction
"""
new = cls()
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
new.register_component(component=rew_instance, weight=weight)
return new