From 64b9ba3ecf2e1865e902917bd80d05bac70ab0bd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 20 Feb 2024 16:21:03 +0000 Subject: [PATCH] Make environment reset reinstantiate the game --- src/primaite/game/game.py | 5 +--- .../notebooks/training_example_sb3.ipynb | 4 +-- src/primaite/session/environment.py | 27 ++++++++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1fd0dc8b..091438ce 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -67,9 +67,6 @@ class PrimaiteGame: self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -163,7 +160,7 @@ class PrimaiteGame: return True return False - def reset(self) -> None: + def reset(self) -> None: # TODO: deprecated - remove me """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e5085c5e..164142b2 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -38,7 +38,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..ad770f8f 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -18,11 +18,18 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + """Current game.""" self.agent: ProxyAgent = self.game.rl_agents[0] + """The agent within the game that is controlled by the RL algorithm.""" + + self.episode_counter: int = 0 + """Current episode number.""" def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -45,13 +52,13 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -63,10 +70,12 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( - f"Resetting environment, episode {self.game.episode_counter}, " + f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" ) - self.game.reset() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.agent = self.game.rl_agents[0] + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -107,7 +116,7 @@ class PrimaiteRayEnv(gymnasium.Env): :type env_config: Dict[str, PrimaiteGame] """ self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -194,13 +203,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards,