Align step counts in logging
This commit is contained in:
@@ -47,6 +47,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.game.apply_agent_actions()
|
||||
@@ -62,18 +63,18 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
self._write_step_metadata_json(step, action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
def _write_step_metadata_json(self, step: int, action: int, state: Dict, reward: int):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"step": step,
|
||||
"action": int(action),
|
||||
"reward": int(reward),
|
||||
"state": state,
|
||||
@@ -121,6 +122,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
else:
|
||||
return self.agent.observation_manager.current_observation
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
@@ -144,6 +151,10 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
@@ -211,6 +222,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
@@ -232,18 +244,18 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
@@ -260,3 +272,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
Reference in New Issue
Block a user