diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 59fd46c4..140df1b8 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -6,9 +6,14 @@ "source": [ "# Training an SB3 Agent\n", "\n", - "This notebook will demonstrate how to use primaite to create and train a PPO agent.\n", - "\n", - "#### First, import `PrimaiteGymEnv` and read our config file" + "This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### First, we import the inital packages and read in our configuration file." ] }, { @@ -38,7 +43,14 @@ "outputs": [], "source": [ "with open(data_manipulation_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n" + " cfg = yaml.safe_load(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the given configuration, we generate the environment our agent will train in." ] }, { @@ -50,6 +62,13 @@ "gym = PrimaiteGymEnv(game_config=cfg)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets define training parameters for the agent." + ] + }, { "cell_type": "code", "execution_count": null, @@ -71,7 +90,14 @@ "metadata": {}, "outputs": [], "source": [ - "model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n" + "model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the agent configured, let's train for our defined number of episodes." ] }, { @@ -80,8 +106,14 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=NO_STEPS)\n", - "model.save(\"PrimAITE-PPO-example-agent\")" + "model.learn(total_timesteps=NO_STEPS)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's save the agent to a zip file that can be used in future evaluation." ] }, { @@ -93,6 +125,13 @@ "model.save(\"PrimAITE-PPO-example-agent\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we load the saved agent and run it in evaluation mode." + ] + }, { "cell_type": "code", "execution_count": null, @@ -103,6 +142,13 @@ "eval_model = PPO.load(\"PrimAITE-PPO-example-agent\", gym)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, evaluate the agent." + ] + }, { "cell_type": "code", "execution_count": null,