Merged PR 507: #2748: Port of PrimAITE Internal changes.

## Summary
This is a port of the AgentHistoryItem DB Admin (GreenAdminDatabaseUnreachablePenalty reward) changes that were made to the PrimAITE Internal repo.
See also #2826.

## Test process
Updated tests/integration_tests/game_layer/test_rewards.py.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [X] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

#2748: Port of PrimAITE Internal changes.

Related work items: #2748
This commit is contained in:
Nick Todd
2024-08-19 16:09:52 +00:00
4 changed files with 63 additions and 22 deletions

View File

@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added actions to change users' passwords.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
- Added reward calculation details to AgentHistoryItem.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.

View File

@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
reward: Optional[float] = None
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""

View File

@@ -47,7 +47,15 @@ class AbstractReward:
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
@@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
@@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
return 1.0 if last_connection_successful else -1.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -346,7 +369,15 @@ class SharedReward(AbstractReward):
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
"""Simply access the other agent's reward and return it.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
@@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward):
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
"""Calculate the penalty to be applied.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:

View File

@@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
ahi = AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={},
request=["execute"],
response=response,
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == 1.0
assert ahi.reward_info == {"last_connection_successful": True}
router.acl.remove_rule(position=2)
@@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == -1.0
assert ahi.reward_info == {"last_connection_successful": False}
def test_shared_reward():