From 9b3699389acf18613b41de01524cd65d883c07a4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Apr 2024 15:36:59 +0100 Subject: [PATCH] #2523 - Adding in some additional logging for other agent classes. This currently prints total_reward instead of average reward --- src/primaite/session/environment.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a34ebf04..6c42c701 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env): """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" - self.episode_counter: int = 0 """Current episode number.""" @@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen by the RL policy step = self.game.step_counter self.agent.store_action(action) - # apply_agent_actions accesses the action we just stored self.game.pre_timestep() + # apply_agent_actions accesses the action we just stored self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() @@ -205,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} + _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") + if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + self.episode_counter += 1 self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) @@ -245,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 4. Get rewards rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} infos = {name: {} for name, _ in self.agents.items()}