From b81dd26b713f82d422b09bc666fc046626437760 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 13:12:08 +0000 Subject: [PATCH] Add Ray env class --- src/primaite/game/environment.py | 24 ++++++++-- ...agent.ipynb => training_example_sb3.ipynb} | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) rename src/primaite/notebooks/{train_rllib_single_agent.ipynb => training_example_sb3.ipynb} (68%) diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index 57846b99..d540bd02 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType @@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]): + def __init__(self, game: PrimaiteGame): """Initialise the environment.""" super().__init__() self.game: "PrimaiteGame" = game - self.agent: ProxyAgent = agents[0] + self.agent: ProxyAgent = self.game.rl_agents[0] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" @@ -63,3 +63,21 @@ class PrimaiteGymEnv(gymnasium.Env): unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + + +class PrimaiteRayEnv(gymnasium.Env): + """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment.""" + self.env = PrimaiteGymEnv(game=env_config["game"]) + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + return self.env.reset(seed=seed) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + """Perform a step in the environment.""" + return self.env.step(action) diff --git a/src/primaite/notebooks/train_rllib_single_agent.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb similarity index 68% rename from src/primaite/notebooks/train_rllib_single_agent.ipynb rename to src/primaite/notebooks/training_example_sb3.ipynb index 709e6e6f..e4033a79 100644 --- a/src/primaite/notebooks/train_rllib_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -55,6 +55,53 @@ "source": [ "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "model = PPO('MlpPolicy', gym)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.learn(total_timesteps=1000)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "model.save(\"deleteme\")" + ] } ], "metadata": {