Change get_action signature for agents

This commit is contained in:
Marek Wolan
2024-03-04 10:43:38 +00:00
parent d1480e4477
commit ac9d550e9b
7 changed files with 27 additions and 28 deletions

View File

@@ -1,5 +1,5 @@
import random
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple
from gymnasium.core import ObsType
@@ -26,14 +26,14 @@ class DataManipulationAgent(AbstractScriptedAgent):
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]:
"""Waits until a specific timestep, then attempts to execute its data manipulation application.
:param obs: _description_
:param obs: Current observation for this agent, not used in DataManipulationAgent
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:param timestep: The current simulation timestep, used for scheduling actions
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:

View File

@@ -112,7 +112,7 @@ class AbstractAgent(ABC):
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -152,14 +152,14 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: _description_
:param obs: Current observation for this agent, not used in RandomAgent
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:param timestep: The current simulation timestep, not used in RandomAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
@@ -185,14 +185,14 @@ class ProxyAgent(AbstractAgent):
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
:type obs: ObsType
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
:type reward: float, optional
:param timestep: Current simulation timestep. Not used by ProxyAgents, but required by the interface.
:type timestep: int
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""

View File

@@ -270,7 +270,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return -1.0
elif last_connection_successful is True:
return 1.0
return 0
return 0.0
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:

View File

@@ -70,17 +70,17 @@ class ProbabilisticAgent(AbstractScriptedAgent):
super().__init__(agent_name, action_space, observation_space, reward_function)
def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""
Choose a random action from the action space.
Sample the action space randomly.
The probability of each action is given by the corresponding index in ``self.probabilities``.
:param obs: Current observation of the simulation
:param obs: Current observation for this agent, not used in ProbabilisticAgent
:type obs: ObsType
:param reward: Reward for the last step, not used for scripted agents, defaults to 0
:type reward: float, optional
:return: Action to be taken in CAOS format.
:param timestep: The current simulation timestep, not used in ProbabilisticAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)

View File

@@ -165,8 +165,7 @@ class PrimaiteGame:
agent_actions = {}
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter)
action_choice, options = agent.get_action(obs, timestep=self.step_counter)
agent_actions[agent.agent_name] = (action_choice, options)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)

View File

@@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent):
)
self.most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]:
def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
return self.most_recent_action

View File

@@ -69,7 +69,7 @@ def test_probabilistic_agent():
node_application_execute_count = 0
node_file_delete_count = 0
for _ in range(N_TRIALS):
a = pa.get_action(0, timestep=0)
a = pa.get_action(0)
if a == ("DONOTHING", {}):
do_nothing_count += 1
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):