Add Ray env class

This commit is contained in:
Marek Wolan
2023-11-22 13:12:08 +00:00
parent 1138644a4b
commit b81dd26b71
2 changed files with 68 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple from typing import Any, Dict, Optional, SupportsFloat, Tuple
import gymnasium import gymnasium
from gymnasium.core import ActType, ObsType from gymnasium.core import ActType, ObsType
@@ -15,11 +15,11 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1. assumptions about the agent list always having a list of length 1.
""" """
def __init__(self, game: PrimaiteGame, agents: List[ProxyAgent]): def __init__(self, game: PrimaiteGame):
"""Initialise the environment.""" """Initialise the environment."""
super().__init__() super().__init__()
self.game: "PrimaiteGame" = game self.game: "PrimaiteGame" = game
self.agent: ProxyAgent = agents[0] self.agent: ProxyAgent = self.game.rl_agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment.""" """Perform a step in the environment."""
@@ -63,3 +63,21 @@ class PrimaiteGymEnv(gymnasium.Env):
unflat_space = self.agent.observation_manager.space unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs) return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteRayEnv(gymnasium.Env):
    """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""

    def __init__(self, env_config: Dict) -> None:
        """Initialise the environment.

        :param env_config: Ray-style config dict; must contain a ``"game"`` key
            holding the game instance passed through to ``PrimaiteGymEnv``.
        """
        self.env = PrimaiteGymEnv(game=env_config["game"])
        # Mirror the wrapped env's spaces so Ray can inspect them directly.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict]:
        """Reset the environment.

        :param seed: Optional RNG seed forwarded to the wrapped env.
        :param options: Accepted for Gymnasium API compatibility.
            NOTE(review): currently not forwarded to the wrapped env — confirm
            whether ``PrimaiteGymEnv.reset`` accepts it before forwarding.
        :return: Tuple of (observation, info dict) from the wrapped env.
        """
        return self.env.reset(seed=seed)

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
        """Perform a step in the environment.

        :param action: Action to apply to the wrapped env.
        :return: Tuple of (observation, reward, terminated, truncated, info)
            as returned by the wrapped env.
        """
        return self.env.step(action)

View File

@@ -55,6 +55,53 @@
"source": [ "source": [
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
] ]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.learn(total_timesteps=1000)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
]
} }
], ],
"metadata": { "metadata": {