From 2f3e40fb6b6abe943770119b109319bc0edb7266 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Feb 2024 13:22:05 +0000 Subject: [PATCH] Fix issue around reset --- src/primaite/game/game.py | 2 +- src/primaite/session/environment.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index eeb0d007..3b9a21d4 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -417,7 +417,7 @@ class PrimaiteGame: agent_settings=agent_settings, ) else: - msg(f"Configuration error: {agent_type} is not a valid agent type.") + msg = f"Configuration error: {agent_type} is not a valid agent type." _LOGGER.error(msg) raise ValueError(msg) game.agents[agent_cfg["ref"]] = new_agent diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index f8dbab9d..d54503a3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,3 +1,4 @@ +import copy import json from typing import Any, Dict, Optional, SupportsFloat, Tuple @@ -23,7 +24,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.game_config: Dict = game_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" @@ -78,7 +79,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state()