#2453 - Committing additional explanations to notebook

This commit is contained in:
Charlie Crane
2024-04-18 13:52:43 +01:00
parent bb88d43b90
commit abf94fc4bb

View File

@@ -6,9 +6,14 @@
"source": [
"# Training an SB3 Agent\n",
"\n",
"This notebook will demonstrate how to use primaite to create and train a PPO agent.\n",
"\n",
"#### First, import `PrimaiteGymEnv` and read our config file"
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First, we import the inital packages and read in our configuration file."
]
},
{
@@ -38,7 +43,14 @@
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n"
" cfg = yaml.safe_load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the given configuration, we generate the environment our agent will train in."
]
},
{
@@ -50,6 +62,13 @@
"gym = PrimaiteGymEnv(game_config=cfg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets define training parameters for the agent."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -71,7 +90,14 @@
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the agent configured, let's train for our defined number of episodes."
]
},
{
@@ -80,8 +106,14 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=NO_STEPS)\n",
"model.save(\"PrimAITE-PPO-example-agent\")"
"model.learn(total_timesteps=NO_STEPS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's save the agent to a zip file that can be used in future evaluation."
]
},
{
@@ -93,6 +125,13 @@
"model.save(\"PrimAITE-PPO-example-agent\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we load the saved agent and run it in evaluation mode."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -103,6 +142,13 @@
"eval_model = PPO.load(\"PrimAITE-PPO-example-agent\", gym)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, evaluate the agent."
]
},
{
"cell_type": "code",
"execution_count": null,