#2523 - Add additional logging for the other agent classes. Note: the reset log currently prints total_reward rather than the average reward.

This commit is contained in:
Charlie Crane
2024-04-30 15:36:59 +01:00
parent b8c46a92e9
commit 9b3699389a

View File

@@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env):
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
self.episode_counter: int = 0
"""Current episode number."""
@@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env):
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.pre_timestep()
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
@@ -205,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
@@ -245,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}