From e48b71ea1a1627009c8259b1f44ce570aecf113d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 31 May 2024 15:25:08 +0100 Subject: [PATCH] get ray to stop crashing --- pyproject.toml | 2 +- src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb | 4 ++-- src/primaite/notebooks/Training-an-RLLib-Agent.ipynb | 7 +++---- src/primaite/notebooks/Training-an-SB3-Agent.ipynb | 7 +++++-- src/primaite/session/ray_envs.py | 3 ++- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9c94a388..d01299be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.9, < 3", + "ray[rllib] >= 2.20.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", ] diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 61b988c6..5ffb19ad 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -60,8 +60,8 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", " )\n" ] diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 9d458426..fbc5f4c6 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -19,7 +19,6 @@ "from primaite.config.load import data_manipulation_config_path\n", "\n", "from primaite.session.ray_envs import PrimaiteRayEnv\n", - "from ray.rllib.algorithms import ppo\n", "from ray import air, tune\n", "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", @@ -52,8 +51,8 @@ "\n", "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayEnv, env_config=env_config)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", ")\n" ] @@ -74,7 +73,7 @@ "tune.Tuner(\n", " \"PPO\",\n", " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128}\n", + " stop={\"timesteps_total\": 512}\n", " ),\n", " param_space=config\n", ").fit()\n" diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 9faf5820..1e247e81 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -43,7 +43,10 @@ "outputs": [], "source": [ "with open(data_manipulation_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)" + " cfg = yaml.safe_load(f)\n", + "for agent in cfg['agents']:\n", + " if agent['ref'] == 'defender':\n", + " agent['agent_settings']['flatten_obs']=True" ] }, { @@ -177,7 +180,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 6dddde51..111baf84 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -45,7 +45,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) - + self._obs_space_in_preferred_format = True + self._action_space_in_preferred_format = True super().__init__() @property