From 9070fb44d4451b36226bd48af6e10a8fe92d5dd6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 22 Nov 2023 13:26:29 +0000 Subject: [PATCH] Check that ray single agent training works --- src/primaite/game/environment.py | 9 +- .../training_example_ray_single_agent.ipynb | 129 ++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 src/primaite/notebooks/training_example_ray_single_agent.ipynb diff --git a/src/primaite/game/environment.py b/src/primaite/game/environment.py index d540bd02..8ddcb88a 100644 --- a/src/primaite/game/environment.py +++ b/src/primaite/game/environment.py @@ -68,8 +68,13 @@ class PrimaiteGymEnv(gymnasium.Env): class PrimaiteRayEnv(gymnasium.Env): """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" - def __init__(self, env_config: Dict) -> None: - """Initialise the environment.""" + def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ self.env = PrimaiteGymEnv(game=env_config["game"]) self.action_space = self.env.action_space self.observation_space = self.env.observation_space diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/training_example_ray_single_agent.ipynb new file mode 100644 index 00000000..f47722f5 --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_single_agent.ipynb @@ -0,0 +1,129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.game.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray.rllib.algorithms import ppo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_config = {\"game\":game}\n", + "config = {\n", + " \"env\" : PrimaiteRayEnv,\n", + " \"env_config\" : env_config,\n", + " \"disable_env_checking\": True,\n", + " \"num_rollout_workers\": 0,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo = ppo.PPO(config=config)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(5):\n", + " result = algo.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo.save(\"temp/deleteme\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}