#2845: Changed to store obs data within AgentHistoryItem

This commit is contained in:
Nick Todd
2024-09-03 14:38:19 +01:00
parent 5cacbf0337
commit 8e57e707b3
4 changed files with 17 additions and 29 deletions

View File

@@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel):
reward_info: Dict[str, Any] = {}
obs_space_data: Optional[ObsType] = None
"""The observation space data for this step."""
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -169,12 +172,23 @@ class AbstractAgent(ABC):
return request
def process_action_response(
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
self,
timestep: int,
action: str,
parameters: Dict[str, Any],
request: RequestFormat,
response: RequestResponse,
obs_space_data: ObsType,
) -> None:
"""Process the response from the most recent action."""
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
timestep=timestep,
action=action,
parameters=parameters,
request=request,
response=response,
obs_space_data=obs_space_data,
)
)

View File

@@ -186,6 +186,7 @@ class PrimaiteGame:
parameters=parameters,
request=request,
response=response,
obs_space_data=obs,
)
def pre_timestep(self) -> None:

View File

@@ -112,9 +112,6 @@ class PrimaiteGymEnv(gymnasium.Env):
self.game.update_agents(state)
next_obs = self._get_obs() # this doesn't update observation, just gets the current observation
if self.io.settings.obs_space_data:
# Write unflattened observation space to log file.
self._write_obs_space_data(self.agent.observation_manager.current_observation)
reward = self.agent.reward_function.current_reward
_LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}")
terminated = False
@@ -142,25 +139,6 @@ class PrimaiteGymEnv(gymnasium.Env):
with open(path, "w") as file:
json.dump(data, file)
def _write_obs_space_data(self, obs_space: ObsType) -> None:
"""Write the unflattened observation space data to a JSON file.
:param obs: Observation of the environment (dict)
:type obs: ObsType
"""
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{self.game.step_counter}.json"
data = {
"episode": self.episode_counter,
"step": self.game.step_counter,
"obs_space_data": obs_space,
}
with open(path, "w") as file:
json.dump(data, file)
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
_LOGGER.info(
@@ -181,9 +159,6 @@ class PrimaiteGymEnv(gymnasium.Env):
state = self.game.get_sim_state()
self.game.update_agents(state=state)
next_obs = self._get_obs()
if self.io.settings.obs_space_data:
# Write unflattened observation space to log file.
self._write_obs_space_data(self.agent.observation_manager.current_observation)
info = {}
return next_obs, info

View File

@@ -45,8 +45,6 @@ class PrimaiteIO:
"""The level of sys logs that should be included in the logfiles/logged into terminal."""
agent_log_level: LogLevel = LogLevel.INFO
"""The level of agent logs that should be included in the logfiles/logged into terminal."""
obs_space_data: bool = False
"""Whether to save observation space data to a log file."""
def __init__(self, settings: Optional[Settings] = None) -> None:
"""