Add Ray env class
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]):
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.game: "PrimaiteGame" = game
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
self.agent: ProxyAgent = self.game.rl_agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
@@ -63,3 +63,21 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment."""
|
||||
self.env = PrimaiteGymEnv(game=env_config["game"])
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
@@ -55,6 +55,53 @@
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=1000)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
Reference in New Issue
Block a user