Merged PR 507: #2748: Port of PrimAITE Internal changes.
## Summary This a port of the AgentHistoryItem DB Admin (GreenAdminDatabaseUnreachablePenalty reward) changes that were made to the PrimAITE Internal repo. See also #2826. ## Test process Updated tests/integration_tests/game_layer/test_rewards.py. ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [X] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code #2748: Port of PrimAITE Internal changes. Related work items: #2748
This commit is contained in:
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added actions to change users' passwords.
|
||||
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
|
||||
main port they're assigned.
|
||||
- Added reward calculation details to AgentHistoryItem.
|
||||
|
||||
### Changed
|
||||
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
|
||||
|
||||
@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
|
||||
|
||||
reward: Optional[float] = None
|
||||
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
|
||||
@@ -47,7 +47,15 @@ class AbstractReward:
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
component will keep track of that information. In that case, it doesn't matter whether the last successful
|
||||
request returned was able to connect to the database server, because there has been an unsuccessful request
|
||||
since.
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
|
||||
self._last_request_failed = last_action_response.response.status != "success"
|
||||
@@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
# If the last request was actually sent, then check if the connection was established.
|
||||
db_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
|
||||
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
|
||||
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
|
||||
return 0.0
|
||||
last_connection_successful = db_state["last_connection_successful"]
|
||||
if last_connection_successful is False:
|
||||
return -1.0
|
||||
elif last_connection_successful is True:
|
||||
return 1.0
|
||||
return 0.0
|
||||
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
|
||||
return 1.0 if last_connection_successful else -1.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
@@ -346,7 +369,15 @@ class SharedReward(AbstractReward):
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
"""Simply access the other agent's reward and return it.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
@@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward):
|
||||
self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied."""
|
||||
"""Calculate the penalty to be applied.
|
||||
|
||||
:param state: Current simulation state
|
||||
:type state: Dict
|
||||
:param last_action_response: Current agent history state
|
||||
:type last_action_response: AgentHistoryItem state
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
return self.do_nothing_penalty
|
||||
else:
|
||||
|
||||
@@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent):
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
ahi = AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="NODE_APPLICATION_EXECUTE",
|
||||
parameters={},
|
||||
request=["execute"],
|
||||
response=response,
|
||||
)
|
||||
reward_value = comp.calculate(state, last_action_response=ahi)
|
||||
assert reward_value == 1.0
|
||||
assert ahi.reward_info == {"last_connection_successful": True}
|
||||
|
||||
router.acl.remove_rule(position=2)
|
||||
|
||||
@@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent):
|
||||
]
|
||||
)
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
reward_value = comp.calculate(state, last_action_response=ahi)
|
||||
assert reward_value == -1.0
|
||||
assert ahi.reward_info == {"last_connection_successful": False}
|
||||
|
||||
|
||||
def test_shared_reward():
|
||||
|
||||
Reference in New Issue
Block a user