#2845: Changes to write observation space data to log file.
This commit is contained in:
@@ -112,6 +112,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
|
||||
if self.io.settings.obs_space_data:
|
||||
# Write unflattened observation space to log file.
|
||||
self._write_obs_space_data(self.agent.observation_manager.current_observation)
|
||||
reward = self.agent.reward_function.current_reward
|
||||
_LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}")
|
||||
terminated = False
|
||||
@@ -139,6 +142,25 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _write_obs_space_data(self, obs_space: ObsType) -> None:
|
||||
"""Write the unflattened observation space data to a JSON file.
|
||||
|
||||
:param obs: Observation of the environment (dict)
|
||||
:type obs: ObsType
|
||||
"""
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"obs_space_data": obs_space,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
_LOGGER.info(
|
||||
@@ -159,6 +181,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
if self.io.settings.obs_space_data:
|
||||
# Write unflattened observation space to log file.
|
||||
self._write_obs_space_data(self.agent.observation_manager.current_observation)
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ class PrimaiteIO:
|
||||
"""The level of sys logs that should be included in the logfiles/logged into terminal."""
|
||||
agent_log_level: LogLevel = LogLevel.INFO
|
||||
"""The level of agent logs that should be included in the logfiles/logged into terminal."""
|
||||
obs_space_data: bool = False
|
||||
"""Whether to save observation space data to a log file."""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user