Get RLLib to stop crashing.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_framework: RLLIB_single_agent
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 25
|
||||
|
||||
67
src/primaite/game/environment.py
Normal file
67
src/primaite/game/environment.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, TYPE_CHECKING
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
        """
        Initialise the environment.

        :param session: Session that this environment drives on every step and reset.
        :param agents: Agents controlled by the RL policy. Only the first element is used.
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """
        Perform a step in the environment.

        :param action: Action chosen by the RL policy for the wrapped agent.
        :return: Tuple of (flattened observation, reward, terminated, truncated, info).
        """
        # make ProxyAgent store the action chosen by the RL policy
        self.agent.store_action(action)
        # apply_agent_actions accesses the action we just stored
        self.session.apply_agent_actions()
        self.session.advance_timestep()
        state = self.session.get_sim_state()
        self.session.update_agents(state)

        next_obs = self._get_obs()
        reward = self.agent.reward_function.current_reward
        # This wrapper never reports termination; an episode only ends through truncation.
        terminated = False
        truncated = self.session.calculate_truncated()
        info: Dict[str, Any] = {}

        return next_obs, reward, terminated, truncated, info

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Reset the environment.

        :param seed: Optional seed for the env's own RNG. It is not forwarded to the session.
        :param options: Accepted for compatibility with the gymnasium ``Env.reset`` signature; currently unused.
        :return: Tuple of (flattened observation, info).
        """
        # Callers that follow the gymnasium API may pass `seed=` and `options=` as keywords;
        # let the base class seed `np_random` as the API recommends.
        super().reset(seed=seed)
        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        info: Dict[str, Any] = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the observation space of the environment, flattened to match ``_get_obs``."""
        return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)

    def _get_obs(self) -> ObsType:
        """Return the agent's current observation, flattened against its observation space."""
        unflat_space = self.agent.observation_manager.space
        unflat_obs = self.agent.observation_manager.current_observation
        return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
@@ -1,3 +1,4 @@
|
||||
from primaite.game.policy.rllib import RaySingleAgentPolicy
|
||||
from primaite.game.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy"]
|
||||
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.game.policy import PolicyABC
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
from ray.rllib
|
||||
import ray
|
||||
from ray.rllib.algorithms import Algorithm, ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
    """Single agent RL policy trained through Ray RLlib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
        """
        Initialise the policy and build the underlying RLlib algorithm.

        :param session: Session whose ``env`` attribute is exposed to RLlib.
        :param algorithm: Requested algorithm name.
            NOTE(review): currently unused, a PPO algorithm is always built - confirm whether A2C should be supported.
        :param seed: Optional seed passed to RLlib through the algorithm config.
        """
        super().__init__(session=session)

        class RayPrimaiteGym(gymnasium.Env):
            """Env class that RLlib can instantiate from ``env_config``; every call is delegated to ``session.env``."""

            def __init__(self, env_config: Dict) -> None:
                """Copy the spaces from the session env. ``env_config`` is not used."""
                self.action_space = session.env.action_space
                self.observation_space = session.env.observation_space

            def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[ObsType, Dict]:
                """Reset the session env. ``seed`` and ``options`` are accepted but not forwarded."""
                obs, info = session.env.reset()
                return obs, info

            def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
                """Step the session env with the given action."""
                obs, reward, terminated, truncated, info = session.env.step(action)
                return obs, reward, terminated, truncated, info

        # Always start from a fresh Ray runtime. Shutting down first means an already
        # initialised Ray instance cannot make `ray.init()` fail with a re-init error.
        ray.shutdown()
        ray.init()

        config = {
            "env": RayPrimaiteGym,
            "env_config": {},
            "disable_env_checking": True,
            # No rollout workers are requested because the env class closes over this
            # process's session object.
            "num_rollout_workers": 0,
            "seed": seed,
        }

        self._algo = ppo.PPO(config=config)

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """
        Train the agent.

        :param n_episodes: Number of times ``train()`` is called on the algorithm.
            NOTE(review): one RLlib ``train()`` call is a training iteration and may not
            correspond to exactly one episode - confirm.
        :param timesteps_per_episode: Currently unused.
        """
        for ep in range(n_episodes):
            res = self._algo.train()
            print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate the agent. Not implemented yet."""
        raise NotImplementedError

    def save(self, save_path: Path) -> None:
        """Save the policy to ``save_path``. Not implemented yet."""
        raise NotImplementedError

    def load(self, model_path: Path) -> None:
        """Load the policy from ``model_path``. Not implemented yet."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
        """Create a policy from a config."""
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
import enlighten
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -14,6 +13,7 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
@@ -39,64 +39,6 @@ progress_bar_manager = enlighten.get_manager()
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
    """
    Thin wrapper env to provide agents with a gymnasium API.

    This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
    assumptions about the agent list always having a list of length 1.
    """

    def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
        """
        Initialise the environment.

        :param session: Session that is stepped and reset through this environment.
        :param agents: List of RL-controlled agents, of which only the first is kept.
        """
        super().__init__()
        self.session: "PrimaiteSession" = session
        self.agent: ProxyAgent = agents[0]

    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """
        Perform a step in the environment.

        :param action: Action selected by the RL policy.
        :return: Tuple of (flattened observation, reward, terminated, truncated, info).
        """
        # make ProxyAgent store the action chosen by the RL policy
        self.agent.store_action(action)
        # apply_agent_actions accesses the action we just stored
        self.session.apply_agent_actions()
        self.session.advance_timestep()
        state = self.session.get_sim_state()
        self.session.update_agents(state)

        next_obs = self._get_obs()
        reward = self.agent.reward_function.current_reward
        # Termination is never signalled here; only truncation can end an episode.
        terminated = False
        truncated = self.session.calculate_truncated()
        info: Dict[str, Any] = {}

        return next_obs, reward, terminated, truncated, info

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Reset the environment.

        :param seed: Optional seed for the env's own RNG. The session itself is reset without a seed.
        :param options: Present so that gymnasium-style ``reset(seed=..., options=...)`` calls work; unused.
        :return: Tuple of (flattened observation, info).
        """
        # Seed the base class RNG as recommended by the gymnasium API.
        super().reset(seed=seed)
        self.session.reset()
        state = self.session.get_sim_state()
        self.session.update_agents(state)
        next_obs = self._get_obs()
        info: Dict[str, Any] = {}
        return next_obs, info

    @property
    def action_space(self) -> gymnasium.Space:
        """Return the action space of the environment."""
        return self.agent.action_manager.space

    @property
    def observation_space(self) -> gymnasium.Space:
        """Return the flattened observation space of the environment."""
        return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)

    def _get_obs(self) -> ObsType:
        """Return the current observation, flattened against the agent's observation space."""
        unflat_space = self.agent.observation_manager.space
        unflat_obs = self.agent.observation_manager.current_observation
        return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
"""
|
||||
Global options which are applicable to all of the agents in the game.
|
||||
@@ -115,7 +57,7 @@ class TrainingOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
|
||||
Reference in New Issue
Block a user