diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 14b97821..aac6c05a 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + obs_space_data: Optional[ObsType] = None + """The observation space data for this step.""" + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -169,12 +172,23 @@ class AbstractAgent(ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + obs_space_data: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + obs_space_data=obs_space_data, ) ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..ed3c84d3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -186,6 +186,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + obs_space_data=obs, ) def pre_timestep(self) -> None: diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 23b86546..c66663e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -112,9 +112,6 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.update_agents(state) next_obs = self._get_obs() # this doesn't update observation, just gets the current observation - if self.io.settings.obs_space_data: - # Write unflattened observation space to log file. - self._write_obs_space_data(self.agent.observation_manager.current_observation) reward = self.agent.reward_function.current_reward _LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}") terminated = False @@ -142,25 +139,6 @@ class PrimaiteGymEnv(gymnasium.Env): with open(path, "w") as file: json.dump(data, file) - def _write_obs_space_data(self, obs_space: ObsType) -> None: - """Write the unflattened observation space data to a JSON file. - - :param obs: Observation of the environment (dict) - :type obs: ObsType - """ - output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data" - - output_dir.mkdir(parents=True, exist_ok=True) - path = output_dir / f"step_{self.game.step_counter}.json" - - data = { - "episode": self.episode_counter, - "step": self.game.step_counter, - "obs_space_data": obs_space, - } - with open(path, "w") as file: - json.dump(data, file) - def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" _LOGGER.info( @@ -181,9 +159,6 @@ class PrimaiteGymEnv(gymnasium.Env): state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() - if self.io.settings.obs_space_data: - # Write unflattened observation space to log file. - self._write_obs_space_data(self.agent.observation_manager.current_observation) info = {} return next_obs, info diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 3627e9e9..78d7cb3c 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -45,8 +45,6 @@ class PrimaiteIO: """The level of sys logs that should be included in the logfiles/logged into terminal.""" agent_log_level: LogLevel = LogLevel.INFO """The level of agent logs that should be included in the logfiles/logged into terminal.""" - obs_space_data: bool = False - """Whether to save observation space data to a log file.""" def __init__(self, settings: Optional[Settings] = None) -> None: """