Add Ray env class
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple
|
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||||
|
|
||||||
import gymnasium
|
import gymnasium
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
@@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env):
|
|||||||
assumptions about the agent list always having a list of length 1.
|
assumptions about the agent list always having a list of length 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]):
|
def __init__(self, game: PrimaiteGame):
|
||||||
"""Initialise the environment."""
|
"""Initialise the environment."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.game: "PrimaiteGame" = game
|
self.game: "PrimaiteGame" = game
|
||||||
self.agent: ProxyAgent = agents[0]
|
self.agent: ProxyAgent = self.game.rl_agents[0]
|
||||||
|
|
||||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||||
"""Perform a step in the environment."""
|
"""Perform a step in the environment."""
|
||||||
@@ -63,3 +63,21 @@ class PrimaiteGymEnv(gymnasium.Env):
|
|||||||
unflat_space = self.agent.observation_manager.space
|
unflat_space = self.agent.observation_manager.space
|
||||||
unflat_obs = self.agent.observation_manager.current_observation
|
unflat_obs = self.agent.observation_manager.current_observation
|
||||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||||
|
|
||||||
|
|
||||||
|
class PrimaiteRayEnv(gymnasium.Env):
|
||||||
|
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||||
|
|
||||||
|
def __init__(self, env_config: Dict) -> None:
|
||||||
|
"""Initialise the environment."""
|
||||||
|
self.env = PrimaiteGymEnv(game=env_config["game"])
|
||||||
|
self.action_space = self.env.action_space
|
||||||
|
self.observation_space = self.env.observation_space
|
||||||
|
|
||||||
|
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||||
|
"""Reset the environment."""
|
||||||
|
return self.env.reset(seed=seed)
|
||||||
|
|
||||||
|
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||||
|
"""Perform a step in the environment."""
|
||||||
|
return self.env.step(action)
|
||||||
|
|||||||
@@ -55,6 +55,53 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
|
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from stable_baselines3 import PPO"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = PPO('MlpPolicy', gym)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model.learn(total_timesteps=1000)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model.save(\"deleteme\")"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
Reference in New Issue
Block a user