diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ac61df4..207d156a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added actions to change users' passwords. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. +- Added reward calculation details to AgentHistoryItem. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index f57dc191..14b97821 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel): reward: Optional[float] = None + reward_info: Dict[str, Any] = {} + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..b913501d 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,7 +47,15 @@ class AbstractReward: @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -67,7 +75,15 @@ class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. - :param state: The current state of the simulation. + :param state: Current simulation state :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: @@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last successful request returned was able to connect to the database server, because there has been an unsuccessful request since. + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: self._last_request_failed = last_action_response.response.status != "success" @@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: - _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"} return 0.0 last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + last_action_response.reward_info = {"last_connection_successful": last_connection_successful} + return 1.0 if last_connection_successful else -1.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -346,7 +369,15 @@ class SharedReward(AbstractReward): """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Simply access the other agent's reward and return it.""" + """Simply access the other agent's reward and return it. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return self.callback(self.agent_name) @classmethod @@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward): self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the penalty to be applied.""" + """Calculate the penalty to be applied. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ if last_action_response.action == "DONOTHING": return self.do_nothing_penalty else: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..e945f482 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), + ahi = AgentHistoryItem( + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={}, + request=["execute"], + response=response, ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 + assert ahi.reward_info == {"last_connection_successful": True} router.acl.remove_rule(position=2) @@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), - ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == -1.0 + assert ahi.reward_info == {"last_connection_successful": False} def test_shared_reward():