Get RLLib to stop crashing.

This commit is contained in:
Marek Wolan
2023-11-17 17:57:57 +00:00
parent 6e5e1e6456
commit 3fb7bce3ce
5 changed files with 149 additions and 66 deletions

View File

@@ -1,5 +1,5 @@
training_config:
rl_framework: SB3
rl_framework: RLLIB_single_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 25

View File

@@ -0,0 +1,67 @@
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, TYPE_CHECKING
import gymnasium
from gymnasium.core import ActType, ObsType
from primaite.game.agent.interface import ProxyAgent
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
        """Initialise the environment.

        :param session: The PrimaiteSession that owns the simulation being stepped.
        :param agents: Agent list; only the first entry is used (single-agent API).
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Perform a step in the environment.

        :param action: Action chosen by the RL policy.
        :return: Tuple of (observation, reward, terminated, truncated, info).
        """
        # make ProxyAgent store the action chosen by the RL policy
        self.agent.store_action(action)
        # apply_agent_actions accesses the action we just stored
        self.session.apply_agent_actions()
        self.session.advance_timestep()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        reward = self.agent.reward_function.current_reward
        # episodes never hit a terminal state here; they end only via truncation
        terminated = False
        truncated = self.session.calculate_truncated()
        info: Dict[str, Any] = {}
        return next_obs, reward, terminated, truncated, info

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """Reset the environment.

        :param seed: Optional RNG seed, forwarded to gymnasium's base ``reset`` to seed ``np_random``.
        :param options: Unused; accepted so callers following the gymnasium API
            (e.g. RLlib, which calls ``reset(seed=..., options=...)``) don't raise TypeError.
        :return: Tuple of (observation, info).
        """
        # previously the seed was silently ignored; gymnasium expects it forwarded to the base class
        super().reset(seed=seed)
        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        info: Dict[str, Any] = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment."""
        return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)

    def _get_obs(self) -> ObsType:
        """Return the current observation, flattened to match ``observation_space``."""
        unflat_space = self.agent.observation_manager.space
        unflat_obs = self.agent.observation_manager.current_observation
        return gymnasium.spaces.flatten(unflat_space, unflat_obs)

View File

@@ -1,3 +1,4 @@
from primaite.game.policy.rllib import RaySingleAgentPolicy
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -1,13 +1,20 @@
from pathlib import Path
from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union
import gymnasium
from gymnasium.core import ActType, ObsType
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from primaite.game.policy import PolicyABC
from primaite.game.environment import PrimaiteGymEnv
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.agent.interface import ProxyAgent
from primaite.game.session import PrimaiteSession, TrainingOptions
import ray.rllib
import ray
from ray.rllib.algorithms import Algorithm, ppo
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
@@ -15,4 +22,70 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
    """Initialise the Ray RLlib single-agent policy.

    :param session: The PrimaiteSession that owns the gym environment being trained on.
    :param algorithm: Name of the RL algorithm to use ("PPO" or "A2C").
    :param seed: Optional random seed. NOTE(review): currently unused — confirm whether it
        should be passed into the RLlib algorithm config.
    """
    super().__init__(session=session)

    class RayPrimaiteGym(gymnasium.Env):
        """Adapter exposing the session's single-agent env through the class-based API Ray expects."""

        def __init__(self, env_config) -> None:
            # Ray constructs the env itself from a class + env_config, so the real env is
            # reached through the closure over ``session`` rather than via env_config.
            self.action_space = session.env.action_space
            self.observation_space = session.env.observation_space

        def reset(self, *, seed=None, options=None):
            # seed/options accepted for gymnasium API compatibility; delegated env handles reset
            obs, info = session.env.reset()
            return obs, info

        def step(self, action):
            obs, reward, terminated, truncated, info = session.env.step(action)
            return obs, reward, terminated, truncated, info

    # Ensure a clean Ray runtime: shut down any previous instance before initialising.
    # (The original called ray.init() a second time before this, which is redundant.)
    ray.shutdown()
    ray.init()
    config = {
        "env": RayPrimaiteGym,
        "env_config": {},
        # The env relies on closure state, so skip Ray's pre-flight env checking.
        "disable_env_checking": True,
        # Run rollouts in the driver process; the session env cannot be shipped to workers.
        "num_rollout_workers": 0,
    }
    self._algo = ppo.PPO(config=config)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
    """Train the agent.

    :param n_episodes: Number of RLlib training iterations to run.
    :param timesteps_per_episode: Unused here — NOTE(review): iteration length appears to be
        governed by the RLlib algorithm config; confirm this parameter is intentionally ignored.
    """
    ep = 0
    while ep < n_episodes:
        res = self._algo.train()
        print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")
        ep += 1
def eval(self, n_episodes: int, deterministic: bool) -> None:
    """Evaluate the agent. Not yet implemented for the Ray single-agent policy."""
    raise NotImplementedError
def save(self, save_path: Path) -> None:
    """Save the trained model to ``save_path``. Not yet implemented for the Ray single-agent policy."""
    raise NotImplementedError
def load(self, model_path: Path) -> None:
    """Load a saved model from ``model_path``. Not yet implemented for the Ray single-agent policy."""
    raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
    """Create a policy from a config."""
    return cls(
        session=session,
        algorithm=config.rl_algorithm,
        seed=config.seed,
    )

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import enlighten
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict
@@ -14,6 +13,7 @@ from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.environment import PrimaiteGymEnv
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
@@ -39,64 +39,6 @@ progress_bar_manager = enlighten.get_manager()
_LOGGER = getLogger(__name__)
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
        """Initialise the environment.

        :param session: The PrimaiteSession that owns the simulation being stepped.
        :param agents: Agent list; only the first entry is used (single-agent API).
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Perform a step in the environment.

        :param action: Action chosen by the RL policy.
        :return: Tuple of (observation, reward, terminated, truncated, info).
        """
        # make ProxyAgent store the action chosen by the RL policy
        self.agent.store_action(action)
        # apply_agent_actions accesses the action we just stored
        self.session.apply_agent_actions()
        self.session.advance_timestep()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        reward = self.agent.reward_function.current_reward
        # episodes never hit a terminal state here; they end only via truncation
        terminated = False
        truncated = self.session.calculate_truncated()
        info: Dict[str, Any] = {}
        return next_obs, reward, terminated, truncated, info

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """Reset the environment.

        :param seed: Optional RNG seed, forwarded to gymnasium's base ``reset`` to seed ``np_random``.
        :param options: Unused; accepted so callers following the gymnasium API
            (e.g. RLlib, which calls ``reset(seed=..., options=...)``) don't raise TypeError.
        :return: Tuple of (observation, info).
        """
        # previously the seed was silently ignored; gymnasium expects it forwarded to the base class
        super().reset(seed=seed)
        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        info: Dict[str, Any] = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment."""
        return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)

    def _get_obs(self) -> ObsType:
        """Return the current observation, flattened to match ``observation_space``."""
        unflat_space = self.agent.observation_manager.space
        unflat_obs = self.agent.observation_manager.current_observation
        return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
"""
Global options which are applicable to all of the agents in the game.
@@ -115,7 +57,7 @@ class TrainingOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_framework: Literal["SB3", "RLLIB_single_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None