Minor fixes
This commit is contained in:
@@ -139,8 +139,12 @@ class PrimaiteGame:
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping. Step counter: {self.step_counter}")
|
||||
|
||||
if self.step_counter == 0:
|
||||
state = self.get_sim_state()
|
||||
for agent in self.agents.values():
|
||||
agent.update_observation(state=state)
|
||||
# Apply all actions to simulation as requests
|
||||
action_data = self.apply_agent_actions()
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
@@ -149,7 +153,7 @@ class PrimaiteGame:
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state, and the response from the last action
|
||||
self.update_agents(state=sim_state, action_data=action_data)
|
||||
self.update_agents(state=sim_state)
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
@@ -458,6 +462,7 @@ class PrimaiteGame:
|
||||
|
||||
# Set the NMNE capture config
|
||||
set_nmne_config(network_config.get("nmne_config", {}))
|
||||
game.update_agents(game.get_sim_state())
|
||||
|
||||
return game
|
||||
|
||||
|
||||
@@ -189,8 +189,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
|
||||
Reference in New Issue
Block a user