diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index b5de9a5a..c758c926 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from gymnasium.core import ObsType @@ -26,14 +26,14 @@ class DataManipulationAgent(AbstractScriptedAgent): ) self.next_execution_timestep = timestep + random_timestep_increment - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: + """Waits until a specific timestep, then attempts to execute its data manipulation application. - :param obs: _description_ + :param obs: Current observation for this agent, not used in DataManipulationAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, used for scheduling actions + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 4f434bad..88848479 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -112,7 +112,7 @@ class AbstractAgent(ABC): return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. 
@@ -152,14 +152,14 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Sample the action space randomly. - :param obs: _description_ + :param obs: Current observation for this agent, not used in RandomAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, not used in RandomAgent + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ return self.action_manager.get_action(self.action_manager.space.sample()) @@ -185,14 +185,14 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. :type obs: ObsType - :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. - :type reward: float, optional + :param timestep: Current simulation timestep. Not used by ProxyAgents, but required by the interface. + :type timestep: int :return: Action to be taken in CAOS format. 
:rtype: Tuple[str, Dict] """ diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 882ad024..8c8e36ad 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -270,7 +270,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return -1.0 elif last_connection_successful is True: return 1.0 - return 0 + return 0.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index 28d94062..5111df32 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -70,17 +70,17 @@ class ProbabilisticAgent(AbstractScriptedAgent): super().__init__(agent_name, action_space, observation_space, reward_function) - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ - Choose a random action from the action space. + Sample the action space randomly. The probability of each action is given by the corresponding index in ``self.probabilities``. - :param obs: Current observation of the simulation + :param obs: Current observation for this agent, not used in ProbabilisticAgent :type obs: ObsType - :param reward: Reward for the last step, not used for scripted agents, defaults to 0 - :type reward: float, optional - :return: Action to be taken in CAOS format. 
+ :param timestep: The current simulation timestep, not used in ProbabilisticAgent + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 0749e5db..cd88d832 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -165,8 +165,7 @@ class PrimaiteGame: agent_actions = {} for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter) + action_choice, options = agent.get_action(obs, timestep=self.step_counter) agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) diff --git a/tests/conftest.py b/tests/conftest.py index b60de730..a117a1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index f0b37cac..73228e36 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -69,7 +69,7 @@ def test_probabilistic_agent(): node_application_execute_count = 0 node_file_delete_count = 0 for _ in range(N_TRIALS): - a = pa.get_action(0, timestep=0) + a = pa.get_action(0) if a == ("DONOTHING", {}): 
do_nothing_count += 1 elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):