Align step counts in logging

This commit is contained in:
Marek Wolan
2024-03-15 16:17:38 +00:00
parent a8c98c28f1
commit b4d310eda2

View File

@@ -47,6 +47,7 @@ class PrimaiteGymEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.game.apply_agent_actions()
@@ -62,18 +63,18 @@ class PrimaiteGymEnv(gymnasium.Env):
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
self._write_step_metadata_json(step, action, state, reward)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
def _write_step_metadata_json(self, step: int, action: int, state: Dict, reward: int):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": self.game.step_counter,
"step": step,
"action": int(action),
"reward": int(reward),
"state": state,
@@ -121,6 +122,12 @@ class PrimaiteGymEnv(gymnasium.Env):
else:
return self.agent.observation_manager.current_observation
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
@@ -144,6 +151,10 @@ class PrimaiteRayEnv(gymnasium.Env):
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
@@ -211,6 +222,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
@@ -232,18 +244,18 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": self.game.step_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
@@ -260,3 +272,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)