diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3409100e..38e9d5fc 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,12 +1,10 @@ """PrimAITE game - Encapsulates the simulation and agents.""" -import json -import os from ipaddress import IPv4Address from typing import Dict, List from pydantic import BaseModel, ConfigDict -from primaite import getLogger, PRIMAITE_PATHS +from primaite import getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent @@ -109,13 +107,6 @@ class PrimaiteGame: # Get the current state of the simulation sim_state = self.get_sim_state() - # Create state suitable for dumping to JSON file. - dump_state = {self.episode_counter: {self.step_counter: sim_state}} - # Dump to file - if os.path.isfile(PRIMAITE_PATHS.episode_steps_log_file_path): - with open(PRIMAITE_PATHS.episode_steps_log_file_path, "a") as f: - json.dump(dump_state, f) - # Update agents' observations and rewards based on the current state self.update_agents(sim_state) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a5fdade9..913038f9 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,9 +1,11 @@ +import os from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv +from primaite import PRIMAITE_PATHS from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame @@ -30,6 +32,17 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() + + # Create state suitable for dumping to file. + dump_state = {self.game.episode_counter: {self.game.step_counter: state}} + + # Dump to file + if os.path.isfile(PRIMAITE_PATHS.episode_steps_log_file_path): + with open(PRIMAITE_PATHS.episode_steps_log_file_path, "a", encoding="utf-8") as f: + f.write(str(dump_state)) + f.write("\n=================\n") + f.flush() + self.game.update_agents(state) next_obs = self._get_obs()