From daa34385e550dc43e984706faa2df024e74d1ad7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 9 Jan 2024 14:53:15 +0000 Subject: [PATCH] Add agent reset for episodes --- src/primaite/game/agent/data_manipulation_bot.py | 9 +++++++++ src/primaite/game/agent/interface.py | 4 ++++ src/primaite/game/game.py | 1 + 3 files changed, 14 insertions(+) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 8237ce06..3b558087 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -15,6 +15,7 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + print("red start step: ", self.agent_settings.start_settings.start_step) self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) @@ -27,6 +28,7 @@ class DataManipulationAgent(AbstractScriptedAgent): -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance ) self.next_execution_timestep = timestep + random_timestep_increment + print("next execution red step: ", self.next_execution_timestep) def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. @@ -41,8 +43,15 @@ class DataManipulationAgent(AbstractScriptedAgent): current_timestep = self.action_manager.game.step_counter if current_timestep < self.next_execution_timestep: + print("red agent doing nothing") return "DONOTHING", {"dummy": 0} self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + print("red agent doing an execute") return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} + + def reset_agent_for_episode(self) -> None: + """Set the next execution timestep when the episode resets.""" + super().reset_agent_for_episode() + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 8657fc45..8b6dd6d4 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -135,6 +135,10 @@ class AbstractAgent(ABC): request = self.action_manager.form_request(action_identifier=action, action_options=options) return request + def reset_agent_for_episode(self) -> None: + """Agent reset logic should go here.""" + pass + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 586bca79..08098754 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -162,6 +162,7 @@ class PrimaiteGame: self.simulation.reset_component_for_episode(episode=self.episode_counter) for agent in self.agents: agent.reward_function.total_reward = 0.0 + agent.reset_agent_for_episode() def close(self) -> None: """Close the game, this will close the simulation."""