diff --git a/src/primaite/notebooks/multi-processing.ipynb b/src/primaite/notebooks/multi-processing.ipynb new file mode 100644 index 00000000..83366b54 --- /dev/null +++ b/src/primaite/notebooks/multi-processing.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple multi-processing demo using SubprocVecEnv from SB3\n", + "Based on a code example provided by Rachael Proctor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#!primaite setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import packages and read config file." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up training data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing as mp\n", + "mp.get_all_start_methods()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.common.utils import set_random_seed\n", + "from stable_baselines3.common.vec_env import SubprocVecEnv\n", + "\n", + "from primaite.session.environment import PrimaiteGymEnv\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "EPISODE_LEN = 128\n", + "NUM_EPISODES = 10\n", + "NO_STEPS = EPISODE_LEN * NUM_EPISODES\n", + "BATCH_SIZE = 32\n", + "LEARNING_RATE = 3e-4\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "with open(\"c:/projects/primaite/src/primaite/config/_package_data/data_manipulation.yaml\", \"r\") as f:\n", + " cfg = yaml.safe_load(f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "def make_env(rank: int, seed: int = 0) -> callable:\n", + " \"\"\"Wrapper script for _init function.\"\"\"\n", + "\n", + " def _init() -> PrimaiteGymEnv:\n", + " env = PrimaiteGymEnv(env_config=cfg)\n", + " env.reset(seed=seed + rank)\n", + " model = PPO(\n", + " \"MlpPolicy\",\n", + " env,\n", + " learning_rate=LEARNING_RATE,\n", + " n_steps=NO_STEPS,\n", + " batch_size=BATCH_SIZE,\n", + " verbose=0,\n", + " tensorboard_log=\"./PPO_UC2/\",\n", + " )\n", + " model.learn(total_timesteps=NO_STEPS)\n", + " return env\n", + "\n", + " set_random_seed(seed)\n", + " return _init\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "n_procs = 4\n", + "train_env = SubprocVecEnv([make_env(i + n_procs) for i in range(n_procs)])\n", + "print(train_env)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}