get ray to stop crashing
This commit is contained in:
@@ -54,7 +54,7 @@ license-files = ["LICENSE"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
]
|
||||
|
||||
@@ -60,8 +60,8 @@
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
@@ -52,8 +51,8 @@
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -74,7 +73,7 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128}\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
|
||||
@@ -43,7 +43,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)"
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == 'defender':\n",
|
||||
" agent['agent_settings']['flatten_obs']=True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -177,7 +180,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -45,7 +45,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user