From 14ae8be5e2705a17e9cff45560499ae0c1fa6706 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 23 Nov 2023 00:54:19 +0000 Subject: [PATCH] Update session after it was split from game --- src/primaite/game/policy/policy.py | 10 ++- src/primaite/game/policy/rllib.py | 66 +++++-------------- src/primaite/game/policy/sb3.py | 6 +- src/primaite/main.py | 10 +-- .../training_example_ray_multi_agent.ipynb | 4 +- .../training_example_ray_single_agent.ipynb | 8 ++- .../notebooks/training_example_sb3.ipynb | 50 ++++---------- src/primaite/{game => session}/environment.py | 0 src/primaite/session/session.py | 35 ++++++++-- 9 files changed, 76 insertions(+), 113 deletions(-) rename src/primaite/{game => session}/environment.py (100%) diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 10af44b1..984466d1 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class PolicyABC(ABC): @@ -32,7 +32,7 @@ class PolicyABC(ABC): return @abstractmethod - def __init__(self, session: "PrimaiteGame") -> None: + def __init__(self, session: "PrimaiteSession") -> None: """ Initialize a reinforcement learning policy. @@ -41,7 +41,7 @@ class PolicyABC(ABC): :param agents: The agents to train. :type agents: List[RLAgent] """ - self.session: "PrimaiteGame" = session + self.session: "PrimaiteSession" = session """Reference to the session.""" @abstractmethod @@ -69,7 +69,7 @@ class PolicyABC(ABC): pass @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC": """ Create an RL policy from a config by calling the relevant subclass's from_config method. @@ -80,5 +80,3 @@ class PolicyABC(ABC): PolicyType = cls._registry[config.rl_framework] return PolicyType.from_config(config=config, session=session) - - # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index 7828ccc7..f45b9fd6 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -1,14 +1,11 @@ from pathlib import Path -from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING - -import gymnasium -from gymnasium.core import ActType, ObsType +from typing import Literal, Optional, TYPE_CHECKING from primaite.game.policy.policy import PolicyABC +from primaite.session.environment import PrimaiteRayEnv if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - from primaite.session.session import TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions import ray from ray.rllib.algorithms import ppo @@ -17,64 +14,33 @@ from ray.rllib.algorithms import ppo class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): """Single agent RL policy using Ray RLLib.""" - def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): super().__init__(session=session) - ray.init() - - class RayPrimaiteGym(gymnasium.Env): - def __init__(self, env_config: Dict) -> None: - self.action_space = session.env.action_space - self.observation_space = session.env.observation_space - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - obs, info = session.env.reset() - return obs, info - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: - obs, reward, terminated, truncated, info = session.env.step(action) - return obs, reward, terminated, truncated, info - - ray.shutdown() - ray.init() config = { - "env": RayPrimaiteGym, - "env_config": {}, + "env": PrimaiteRayEnv, + "env_config": {"game": session.game}, "disable_env_checking": True, "num_rollout_workers": 0, } + ray.shutdown() + ray.init() + self._algo = ppo.PPO(config=config) - # self._agent_config = (PPOConfig() - # .update_from_dict({ - # "num_gpus":0, - # "num_workers":0, - # "batch_mode":"complete_episodes", - # "framework":"torch", - # }) - # .environment( - # env="primaite", - # env_config={"session": session, "agents": session.rl_agents,}, - # # disable_env_checking=True - # ) - # # .rollouts(num_rollout_workers=0, - # # num_envs_per_worker=0) - # # .framework("tf2") - # .evaluation(evaluation_num_workers=0) - # ) - - # self._agent:Algorithm = self._agent_config.build(use_copy=False) - def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" for ep in range(n_episodes): - res = self._algo.train() - print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}") + self._algo.train() def eval(self, n_episodes: int, deterministic: bool) -> None: """Evaluate the agent.""" - raise NotImplementedError + for ep in range(n_episodes): + obs, info = self.session.env.reset() + for step in range(self.session.game.options.max_episode_length): + action = self._algo.compute_single_action(observation=obs, explore=False) + obs, rew, term, trunc, info = self.session.env.step(action) def save(self, save_path: Path) -> None: """Save the policy to a file.""" @@ -85,6 +51,6 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): raise NotImplementedError @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index de14ed0c..64eebfc7 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP from primaite.game.policy.policy import PolicyABC if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame, TrainingOptions + from primaite.session.session import PrimaiteSession, TrainingOptions class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" - def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) @@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"): self._agent = self._agent_class.load(model_path, env=self.session.env) @classmethod - def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy": + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy": """Create an agent from config file.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/src/primaite/main.py b/src/primaite/main.py index 5bc76ca2..b63227a7 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,8 +5,8 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger -from primaite.config.load import load -from primaite.game.game import PrimaiteGame +from primaite.config.load import example_config_path, load +from primaite.session.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -32,7 +32,7 @@ def run( otherwise False. """ cfg = load(config_path) - sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path) + sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path) sess.start_session() @@ -42,6 +42,6 @@ if __name__ == "__main__": args = parser.parse_args() if not args.config: - _LOGGER.error("Please provide a config file using the --config " "argument") + args.config = example_config_path() - run(session_path=args.config) + run(args.config) diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb index 9f916af9..d31d53cc 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -10,7 +10,7 @@ "import yaml\n", "from primaite.config.load import example_config_path\n", "\n", - "from primaite.game.environment import PrimaiteRayEnv" + "from primaite.session.environment import PrimaiteRayEnv" ] }, { @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.game.environment import PrimaiteRayMARLEnv\n", + "from primaite.session.environment import PrimaiteRayMARLEnv\n", "\n", "\n", "env_config = {\"game\":game}\n", diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb index f47722f5..9b935346 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -10,7 +10,7 @@ "import yaml\n", "from primaite.config.load import example_config_path\n", "\n", - "from primaite.game.environment import PrimaiteRayEnv" + "from primaite.session.environment import PrimaiteRayEnv" ] }, { @@ -102,7 +102,11 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from primaite.config.load import example_config_path\n", + "from primaite.main import run\n", + "run(example_config_path())" + ] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/training_example_sb3.ipynb index e4033a79..e5085c5e 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/training_example_sb3.ipynb @@ -2,18 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from primaite.game.game import PrimaiteGame\n", - "from primaite.game.environment import PrimaiteGymEnv\n", + "from primaite.session.environment import PrimaiteGymEnv\n", "import yaml" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -22,24 +22,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "installing DNSServer on node domain_controller\n", - "installing DatabaseClient on node web_server\n", - "installing WebServer on node web_server\n", - "installing DatabaseService on node database_server\n", - "service type not found DatabaseBackup\n", - "installing DataManipulationBot on node client_1\n", - "installing DNSClient on node client_1\n", - "installing DNSClient on node client_2\n" - ] - } - ], + "outputs": [], "source": [ "with open(example_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", @@ -49,16 +34,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)" + "gym = PrimaiteGymEnv(game=game)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -76,27 +61,16 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.learn(total_timesteps=1000)\n" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/src/primaite/game/environment.py b/src/primaite/session/environment.py similarity index 100% rename from src/primaite/game/environment.py rename to src/primaite/session/environment.py diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index d7bc3f99..9f567a95 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -1,12 +1,14 @@ from enum import Enum -from typing import Dict, List, Literal, Optional +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.environment import PrimaiteGymEnv +from primaite.game.game import PrimaiteGame # from primaite.game.game import PrimaiteGame from primaite.game.policy.policy import PolicyABC +from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings @@ -15,7 +17,7 @@ class TrainingOptions(BaseModel): model_config = ConfigDict(extra="forbid") - rl_framework: Literal["SB3", "RLLIB_single_agent"] + rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"] rl_algorithm: Literal["PPO", "A2C"] n_learn_episodes: int n_eval_episodes: Optional[int] = None @@ -38,7 +40,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self): + def __init__(self, game: PrimaiteGame): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -46,8 +48,8 @@ class PrimaiteSession: self.mode: SessionMode = SessionMode.MANUAL """Current session mode.""" - self.env: PrimaiteGymEnv - """The environment that the agent can consume. Could be PrimaiteEnv.""" + self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv] + """The environment that the RL algorithm can consume.""" self.policy: PolicyABC """The reinforcement learning policy.""" @@ -55,6 +57,9 @@ class PrimaiteSession: self.io_manager = SessionIO() """IO manager for the session.""" + self.game: PrimaiteGame = game + """Primaite Game object for managing main simulation loop and agents.""" + def start_session(self) -> None: """Commence the training/eval session.""" self.mode = SessionMode.TRAIN @@ -83,10 +88,26 @@ class PrimaiteSession: @classmethod def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" - sess = cls() + game = PrimaiteGame.from_config(cfg) + + sess = cls(game=game) sess.training_options = TrainingOptions(**cfg["training_config"]) # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... io_settings = cfg.get("io_settings", {}) sess.io_manager.settings = SessionIOSettings(**io_settings) + + # CREATE ENVIRONMENT + if sess.training_options.rl_framework == "RLLIB_single_agent": + sess.env = PrimaiteRayEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "RLLIB_multi_agent": + sess.env = PrimaiteRayMARLEnv(env_config={"game": game}) + elif sess.training_options.rl_framework == "SB3": + sess.env = PrimaiteGymEnv(game=game) + + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) + if agent_load_path: + sess.policy.load(Path(agent_load_path)) + + return sess