Make environment reset reinstantiate the game
This commit is contained in:
@@ -67,9 +67,6 @@ class PrimaiteGame:
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteGameOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
@@ -163,7 +160,7 @@ class PrimaiteGame:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self) -> None: # TODO: deprecated - remove me
|
||||
"""Reset the game, this will reset the simulation."""
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game=game)"
|
||||
"gym = PrimaiteGymEnv(game_config=cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -65,7 +65,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=1000)\n"
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -18,11 +18,18 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
def __init__(self, game_config: Dict):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
|
||||
"""Current game."""
|
||||
self.agent: ProxyAgent = self.game.rl_agents[0]
|
||||
"""The agent within the game that is controlled by the RL algorithm."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
@@ -45,13 +52,13 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"action": int(action),
|
||||
"reward": int(reward),
|
||||
@@ -63,10 +70,12 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
print(
|
||||
f"Resetting environment, episode {self.game.episode_counter}, "
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
|
||||
)
|
||||
self.game.reset()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
self.agent = self.game.rl_agents[0]
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
@@ -107,7 +116,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
:type env_config: Dict[str, PrimaiteGame]
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
|
||||
self.env.game.episode_counter -= 1
|
||||
self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -194,13 +203,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{self.game.step_counter}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.game.episode_counter,
|
||||
"episode": self.episode_counter,
|
||||
"step": self.game.step_counter,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
|
||||
Reference in New Issue
Block a user