152 lines
3.3 KiB
Plaintext
152 lines
3.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple multi-processing demonstration\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook uses `SubprocVecEnv` from Stable-Baselines3 (SB3) to run several PrimAITE environments in parallel, each in its own worker process."
]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Import packages and read config file."
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"import yaml\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.common.utils import set_random_seed\n",
"from stable_baselines3.common.vec_env import SubprocVecEnv\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parse the YAML file returned by data_manipulation_config_path() into `cfg`,\n",
"# which each worker later passes to PrimaiteGymEnv as its env_config.\n",
"with open(data_manipulation_config_path(), \"r\") as config_file:\n",
"    cfg = yaml.safe_load(config_file)"
]
},
|
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the training hyperparameters: episode length, number of episodes, total training timesteps, batch size and learning rate."
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters.\n",
"EPISODE_LEN = 128  # intended episode length in steps\n",
"NUM_EPISODES = 10\n",
"NO_STEPS = NUM_EPISODES * EPISODE_LEN  # total training timesteps\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 3e-4"
]
},
|
|
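{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define `make_env`, which returns a zero-argument function that builds a seeded `PrimaiteGymEnv`; `SubprocVecEnv` calls one such function in each worker process."
]
},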
|
|
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_env(rank: int, seed: int = 0) -> Callable[[], PrimaiteGymEnv]:\n",
"    \"\"\"Return a factory that builds one seeded PrimAITE environment.\n",
"\n",
"    ``SubprocVecEnv`` calls each returned factory inside its own worker process, so\n",
"    the factory only constructs and seeds the environment. Training is done\n",
"    afterwards, once, on the vectorised environment.\n",
"\n",
"    :param rank: Index of the worker; added to ``seed`` so each environment is seeded differently.\n",
"    :param seed: Base random seed.\n",
"    :return: Zero-argument callable returning a freshly reset ``PrimaiteGymEnv``.\n",
"    \"\"\"\n",
"\n",
"    def _init() -> PrimaiteGymEnv:\n",
"        env = PrimaiteGymEnv(env_config=cfg)\n",
"        env.reset(seed=seed + rank)\n",
"        return env\n",
"\n",
"    set_random_seed(seed)\n",
"    return _init"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the experiment: start `n_procs` environments in parallel worker processes, then train a single PPO agent on the resulting vectorised environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_procs = 2\n",
"train_env = SubprocVecEnv([make_env(rank) for rank in range(n_procs)])\n",
"\n",
"try:\n",
"    model = PPO(\n",
"        \"MlpPolicy\",\n",
"        train_env,\n",
"        learning_rate=LEARNING_RATE,\n",
"        # SB3's n_steps is per environment (rollout buffer = n_steps * n_envs),\n",
"        # so split NO_STEPS across the workers to keep one rollout at NO_STEPS.\n",
"        n_steps=NO_STEPS // n_procs,\n",
"        batch_size=BATCH_SIZE,\n",
"        verbose=0,\n",
"        tensorboard_log=\"./PPO_UC2/\",\n",
"    )\n",
"    model.learn(total_timesteps=NO_STEPS)\n",
"finally:\n",
"    # Shut down the worker processes even if training fails.\n",
"    train_env.close()"
]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|