#2748: Port of PrimAITE Internal changes.

This commit is contained in:
Nick Todd
2024-08-19 12:55:45 +01:00
parent c886d4b014
commit 2c71958c91
4 changed files with 63 additions and 22 deletions

View File

@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
- Added reward calculation details to `AgentHistoryItem`.
### Changed
- File and folder observations can now be configured to always show the true health status, or require scanning like before.

View File

@@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel):
reward: Optional[float] = None
reward_info: Dict[str, Any] = {}
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""

View File

@@ -47,7 +47,15 @@ class AbstractReward:
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -67,7 +75,15 @@ class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
return 0.0
@classmethod
@@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward):
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
@@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
component will keep track of that information. In that case, it doesn't matter whether the last successful
request returned was able to connect to the database server, because there has been an unsuccessful request
since.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]:
self._last_request_failed = last_action_response.response.status != "success"
@@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
# If the last request was actually sent, then check if the connection was established.
db_state = access_from_nested_dict(state, self.location_in_state)
if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state:
_LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}")
last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"}
return 0.0
last_connection_successful = db_state["last_connection_successful"]
if last_connection_successful is False:
return -1.0
elif last_connection_successful is True:
return 1.0
return 0.0
last_action_response.reward_info = {"last_connection_successful": last_connection_successful}
return 1.0 if last_connection_successful else -1.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
@@ -346,7 +369,15 @@ class SharedReward(AbstractReward):
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
"""Simply access the other agent's reward and return it.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
@@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward):
self.do_nothing_penalty = do_nothing_penalty
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied."""
"""Calculate the penalty to be applied.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem
:return: Reward value
:rtype: float
"""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
else:

View File

@@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
ahi = AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={},
request=["execute"],
response=response,
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == 1.0
assert ahi.reward_info == {"last_connection_successful": True}
router.acl.remove_rule(position=2)
@@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent):
]
)
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == -1.0
assert ahi.reward_info == {"last_connection_successful": False}
def test_shared_reward():