#2523 - Adding in some additional logging for other agent classes. This currently prints total_reward instead of average reward
This commit is contained in:
@@ -35,7 +35,6 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
@@ -49,8 +48,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
# make ProxyAgent store the action chosen by the RL policy
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.pre_timestep()
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
@@ -205,9 +204,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
@@ -245,6 +248,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
|
||||
Reference in New Issue
Block a user