From 5cacbf03373bccd634c8086783222bb45648871a Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 2 Sep 2024 16:54:13 +0100 Subject: [PATCH] #2845: Changes to write observation space data to log file. --- src/primaite/session/environment.py | 25 +++++++++++++++++++++++++ src/primaite/session/io.py | 2 ++ 2 files changed, 27 insertions(+) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index c66663e3..23b86546 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -112,6 +112,9 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.update_agents(state) next_obs = self._get_obs() # this doesn't update observation, just gets the current observation + if self.io.settings.obs_space_data: + # Write unflattened observation space to log file. + self._write_obs_space_data(self.agent.observation_manager.current_observation) reward = self.agent.reward_function.current_reward _LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}") terminated = False @@ -139,6 +142,25 @@ class PrimaiteGymEnv(gymnasium.Env): with open(path, "w") as file: json.dump(data, file) + def _write_obs_space_data(self, obs_space: ObsType) -> None: + """Write the unflattened observation space data to a JSON file. + + :param obs: Observation of the environment (dict) + :type obs: ObsType + """ + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.episode_counter, + "step": self.game.step_counter, + "obs_space_data": obs_space, + } + with open(path, "w") as file: + json.dump(data, file) + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" _LOGGER.info( @@ -159,6 +181,9 @@ class PrimaiteGymEnv(gymnasium.Env): state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() + if self.io.settings.obs_space_data: + # Write unflattened observation space to log file. + self._write_obs_space_data(self.agent.observation_manager.current_observation) info = {} return next_obs, info diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 78d7cb3c..3627e9e9 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -45,6 +45,8 @@ class PrimaiteIO: """The level of sys logs that should be included in the logfiles/logged into terminal.""" agent_log_level: LogLevel = LogLevel.INFO """The level of agent logs that should be included in the logfiles/logged into terminal.""" + obs_space_data: bool = False + """Whether to save observation space data to a log file.""" def __init__(self, settings: Optional[Settings] = None) -> None: """