diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index e0ff9276..c581ae49 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -1,5 +1,5 @@ training_config: - rl_framework: SB3 + rl_framework: RLLIB_single_agent rl_algorithm: PPO seed: 333 n_learn_episodes: 25 diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py new file mode 100644 index 00000000..b88a8202 --- /dev/null +++ b/src/primaite/game/environment.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, TYPE_CHECKING + +import gymnasium +from gymnasium.core import ActType, ObsType + +from primaite.game.agent.interface import ProxyAgent + +if TYPE_CHECKING: + from primaite.game.session import PrimaiteSession + + +class PrimaiteGymEnv(gymnasium.Env): + """ + Thin wrapper env to provide agents with a gymnasium API. + + This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some + assumptions about the agent list always having a list of length 1. 
+ """ + + def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): + """Initialise the environment.""" + super().__init__() + self.session: "PrimaiteSession" = session + self.agent: ProxyAgent = agents[0] + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Perform a step in the environment.""" + # make ProxyAgent store the action chosen by the RL policy + self.agent.store_action(action) + # apply_agent_actions accesses the action we just stored + self.session.apply_agent_actions() + self.session.advance_timestep() + state = self.session.get_sim_state() + self.session.update_agents(state) + + next_obs = self._get_obs() + reward = self.agent.reward_function.current_reward + terminated = False + truncated = self.session.calculate_truncated() + info = {} + + return next_obs, reward, terminated, truncated, info + + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: + """Reset the environment.""" + self.session.reset() + state = self.session.get_sim_state() + self.session.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + @property + def action_space(self) -> gymnasium.Space: + """Return the action space of the environment.""" + return self.agent.action_manager.space + + @property + def observation_space(self) -> gymnasium.Space: + """Return the observation space of the environment.""" + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py index 29196112..9c0e4199 100644 --- a/src/primaite/game/policy/__init__.py +++ b/src/primaite/game/policy/__init__.py @@ -1,3 +1,4 @@ +from 
primaite.game.policy.rllib import RaySingleAgentPolicy from primaite.game.policy.sb3 import SB3Policy -__all__ = ["SB3Policy"] +__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 721a7500..6e9e1096 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,13 +1,20 @@ +from pathlib import Path +from typing import Dict, List, Literal, Optional, SupportsFloat, Tuple, Type, TYPE_CHECKING, Union +import gymnasium +from gymnasium.core import ActType, ObsType -from typing import Literal, Optional, Type, TYPE_CHECKING, Union - -from primaite.game.policy import PolicyABC +from primaite.game.environment import PrimaiteGymEnv +from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: + from primaite.game.agent.interface import ProxyAgent from primaite.game.session import PrimaiteSession, TrainingOptions -from ray.rllib +import ray +from ray.rllib.algorithms import Algorithm, ppo +from ray.rllib.algorithms.ppo import PPOConfig +from ray.tune.registry import register_env class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): @@ -15,4 +22,70 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) + ray.init() + class RayPrimaiteGym(gymnasium.Env): + def __init__(self, env_config: Dict) -> None: + self.action_space = session.env.action_space + self.observation_space = session.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + obs, info = session.env.reset() + return obs, info + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + obs, reward, terminated, truncated, info = session.env.step(action) + return obs, reward, terminated, truncated, info + + ray.shutdown() + ray.init() + + 
config = { + "env": RayPrimaiteGym, + "env_config": {}, + "disable_env_checking": True, + "num_rollout_workers": 0, + } + + self._algo = ppo.PPO(config=config) + + # self._agent_config = (PPOConfig() + # .update_from_dict({ + # "num_gpus":0, + # "num_workers":0, + # "batch_mode":"complete_episodes", + # "framework":"torch", + # }) + # .environment( + # env="primaite", + # env_config={"session": session, "agents": session.rl_agents,}, + # # disable_env_checking=True + # ) + # # .rollouts(num_rollout_workers=0, + # # num_envs_per_worker=0) + # # .framework("tf2") + # .evaluation(evaluation_num_workers=0) + # ) + + # self._agent:Algorithm = self._agent_config.build(use_copy=False) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + + for ep in range(n_episodes): + res = self._algo.train() + print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") + + def eval(self, n_episodes: int, deterministic: bool) -> None: + raise NotImplementedError + + def save(self, save_path: Path) -> None: + raise NotImplementedError + + def load(self, model_path: Path) -> None: + raise NotImplementedError + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": + """Create a policy from a config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a2c04980..aae26fab 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple import enlighten -import gymnasium from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict @@ -14,6 +13,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent from 
primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.environment import PrimaiteGymEnv from primaite.game.io import SessionIO, SessionIOSettings from primaite.game.policy.policy import PolicyABC from primaite.simulator.network.hardware.base import Link, NIC, Node @@ -39,64 +39,6 @@ progress_bar_manager = enlighten.get_manager() _LOGGER = getLogger(__name__) -class PrimaiteGymEnv(gymnasium.Env): - """ - Thin wrapper env to provide agents with a gymnasium API. - - This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some - assumptions about the agent list always having a list of length 1. - """ - - def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]): - """Initialise the environment.""" - super().__init__() - self.session: "PrimaiteSession" = session - self.agent: ProxyAgent = agents[0] - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """Perform a step in the environment.""" - # make ProxyAgent store the action chosen my the RL policy - self.agent.store_action(action) - # apply_agent_actions accesses the action we just stored - self.session.apply_agent_actions() - self.session.advance_timestep() - state = self.session.get_sim_state() - self.session.update_agents(state) - - next_obs = self._get_obs() - reward = self.agent.reward_function.current_reward - terminated = False - truncated = self.session.calculate_truncated() - info = {} - - return next_obs, reward, terminated, truncated, info - - def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: - """Reset the environment.""" - self.session.reset() - state = self.session.get_sim_state() - self.session.update_agents(state) - next_obs = self._get_obs() - info = {} - return next_obs, info - - @property - def action_space(self) -> gymnasium.Space: - """Return the action space of the environment.""" 
- return self.agent.action_manager.space - - @property - def observation_space(self) -> gymnasium.Space: - """Return the observation space of the environment.""" - return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) - - def _get_obs(self) -> ObsType: - """Return the current observation.""" - unflat_space = self.agent.observation_manager.space - unflat_obs = self.agent.observation_manager.current_observation - return gymnasium.spaces.flatten(unflat_space, unflat_obs) - - class PrimaiteSessionOptions(BaseModel): """ Global options which are applicable to all of the agents in the game. @@ -115,7 +57,7 @@ class TrainingOptions(BaseModel): model_config = ConfigDict(extra="forbid") - rl_framework: Literal["SB3", "RLLIB"] + rl_framework: Literal["SB3", "RLLIB_single_agent"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int n_eval_episodes: Optional[int] = None