Add Ray env class

Marek Wolan
2023-11-22 13:12:08 +00:00
parent 1138644a4b
commit b81dd26b71
2 changed files with 68 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, SupportsFloat, Tuple
+from typing import Any, Dict, Optional, SupportsFloat, Tuple
 
 import gymnasium
 from gymnasium.core import ActType, ObsType
@@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env):
     assumptions about the agent list always having a list of length 1.
     """
 
-    def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]):
+    def __init__(self, game: PrimaiteGame):
         """Initialise the environment."""
         super().__init__()
         self.game: "PrimaiteGame" = game
-        self.agent: ProxyAgent = agents[0]
+        self.agent: ProxyAgent = self.game.rl_agents[0]
 
     def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
         """Perform a step in the environment."""
@@ -63,3 +63,21 @@
         unflat_space = self.agent.observation_manager.space
         unflat_obs = self.agent.observation_manager.current_observation
         return gymnasium.spaces.flatten(unflat_space, unflat_obs)
+
+
+class PrimaiteRayEnv(gymnasium.Env):
+    """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
+
+    def __init__(self, env_config: Dict) -> None:
+        """Initialise the environment."""
+        self.env = PrimaiteGymEnv(game=env_config["game"])
+        self.action_space = self.env.action_space
+        self.observation_space = self.env.observation_space
+
+    def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
+        """Reset the environment."""
+        return self.env.reset(seed=seed)
+
+    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
+        """Perform a step in the environment."""
+        return self.env.step(action)
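
For context, a minimal sketch of how this wrapper could be plugged into Ray RLlib. The PPOConfig usage and the pre-built `game` object are assumptions for illustration, not part of this commit; exact RLlib calls depend on the installed Ray version.

# Sketch only: assumes a PrimaiteGame instance `game` has already been built
# (e.g. by the project's own config loader) and that Ray RLlib 2.x is installed.
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # PrimaiteRayEnv takes a single env_config dict, as Ray expects;
    # the wrapper pulls the game out of env_config["game"].
    .environment(env=PrimaiteRayEnv, env_config={"game": game})
    .rollouts(num_rollout_workers=0)  # keep everything in one process for a smoke test
)
algo = config.build()
result = algo.train()  # one training iteration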

View File

@@ -55,6 +55,53 @@
    "source": [
     "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from stable_baselines3 import PPO"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = PPO('MlpPolicy', gym)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.learn(total_timesteps=1000)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.save(\"deleteme\")"
+   ]
   }
  ],
 "metadata": {