PrimAITE/src/primaite/notebooks/UC7-Training.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"# Training an Agent on UC7\n",
"\n",
"© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n",
"\n",
"This notebook follows the same steps as the [Training an SB3 Agent](./Training-an-SB3-Agent.ipynb) notebook, except that it trains an agent on the [use case 7 (UC7) scenario](./UC7-E2E-Demo.ipynb) rather than on [use case 2](./Data-Manipulation-E2E-Demonstration.ipynb). By default, the blue agent in `uc7_config.yaml` (`defender`) is set up to defend against Threat Actor Profile (TAP) 001, which is explored in more detail in the [UC7 TAP001 kill chain notebook](./UC7-TAP001-Kill-Chain-E2E.ipynb).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First, we import the initial packages and set the path to our configuration file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite import PRIMAITE_PATHS\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"\n",
"# Path to the UC7 example scenario in the PrimAITE user config directory\n",
"scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/uc7_config.yaml\""
]
},
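{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above only builds the path to the scenario file; `PrimaiteGymEnv` reads the file itself. If you want to look at the scenario before training, the optional cell below is a minimal sketch that loads the YAML into a plain Python dictionary with PyYAML and lists its top-level sections. It assumes the example config is present at `scenario_path`, and the section names it prints depend on your PrimAITE version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"# Same path as above, redefined so this cell can run on its own\n",
"scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/uc7_config.yaml\"\n",
"\n",
"# Parse the scenario YAML into a plain dict (the file on disk is not modified)\n",
"with open(scenario_path, \"r\") as f:\n",
"    scenario_cfg = yaml.safe_load(f)\n",
"\n",
"# List the top-level sections; the exact keys depend on the PrimAITE version\n",
"print(list(scenario_cfg.keys()))"
]
},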
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(env_config=scenario_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO\n",
"\n",
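"EPISODE_LEN = 128  # intended number of environment steps per episode\n",
"NUM_EPISODES = 10  # number of episodes to train for\n",
"NO_STEPS = EPISODE_LEN * NUM_EPISODES  # total training steps (also used as the PPO rollout length)\n",
"BATCH_SIZE = 32  # minibatch size for each PPO gradient update\n",
"LEARNING_RATE = 3e-4  # optimiser learning rate"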
]
},
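{
"cell_type": "markdown",
"metadata": {},
"source": [
"`NO_STEPS` is passed to PPO as `n_steps` (the rollout length) in the next cell and again as `total_timesteps` when training, so training collects a single rollout and performs a single round of policy optimisation. The optional cell below is a rough sketch of that arithmetic. The helper `ppo_update_schedule` is hypothetical (it is not part of Stable-Baselines3) and assumes one non-vectorised environment, SB3's default `n_epochs=10` for PPO and no early stopping via `target_kl`. With the values above it should report one rollout of 1280 transitions, 40 minibatches per epoch and 400 minibatch gradient steps in total."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def ppo_update_schedule(total_timesteps, n_steps, batch_size, n_envs=1, n_epochs=10):\n",
"    \"\"\"Hypothetical helper: estimate how SB3's PPO splits training into rollouts and minibatches.\"\"\"\n",
"    rollout_size = n_steps * n_envs  # transitions stored in the rollout buffer per rollout\n",
"    num_rollouts = math.ceil(total_timesteps / rollout_size)  # learn() collects whole rollouts until total_timesteps is reached\n",
"    minibatches_per_epoch = math.ceil(rollout_size / batch_size)  # the last minibatch may be smaller\n",
"    return {\n",
"        \"rollout_size\": rollout_size,\n",
"        \"num_rollouts\": num_rollouts,\n",
"        \"minibatches_per_epoch\": minibatches_per_epoch,\n",
"        \"truncated_minibatch_size\": rollout_size % batch_size,  # 0 means every minibatch is full\n",
"        \"gradient_steps\": num_rollouts * n_epochs * minibatches_per_epoch,  # ignores target_kl early stopping\n",
"    }\n",
"\n",
"\n",
"ppo_update_schedule(total_timesteps=NO_STEPS, n_steps=NO_STEPS, batch_size=BATCH_SIZE)"
]
},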
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
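"model = PPO(\n",
"    \"MlpPolicy\",\n",
"    gym,\n",
"    learning_rate=LEARNING_RATE,\n",
"    n_steps=NO_STEPS,  # steps collected per rollout before each policy update\n",
"    batch_size=BATCH_SIZE,\n",
"    verbose=0,\n",
"    tensorboard_log=\"./PPO_UC7/\",  # TensorBoard logs are written here\n",
")"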
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=NO_STEPS)"
]
},
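{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model was created with `tensorboard_log=\"./PPO_UC7/\"`, so Stable-Baselines3 writes training curves to that directory. Assuming the `tensorboard` package is installed in this kernel's environment, the optional cell below uses TensorBoard's Jupyter extension to view them inline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load TensorBoard's notebook extension and point it at the PPO log directory\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir ./PPO_UC7/"
]
},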
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"PrimAITE-PPO-UC7-example-agent\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PPO.load is a class method that returns a new model, so no placeholder model is needed\n",
"eval_model = PPO.load(\"PrimAITE-PPO-UC7-example-agent\", env=gym)"
]
},
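{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `n_eval_episodes=1`, `evaluate_policy` (next cell) runs the trained agent for a single episode. By default it returns a `(mean_reward, std_reward)` tuple computed over the evaluated episodes, so with one episode the standard deviation is always 0 and the mean is simply that episode's total reward. The optional cell below sketches that aggregation in plain Python; `summarise_episode_rewards` is a hypothetical helper and the reward values are placeholders, not PrimAITE results. Increasing `n_eval_episodes` gives a more meaningful estimate of the agent's performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import statistics\n",
"\n",
"\n",
"def summarise_episode_rewards(episode_rewards):\n",
"    \"\"\"Hypothetical helper: mean and population standard deviation of per-episode rewards.\"\"\"\n",
"    mean_reward = statistics.fmean(episode_rewards)\n",
"    std_reward = statistics.pstdev(episode_rewards)  # population std, matching numpy's default np.std\n",
"    return mean_reward, std_reward\n",
"\n",
"\n",
"# Placeholder episode totals for illustration only, not PrimAITE results\n",
"print(summarise_episode_rewards([4.0]))  # one episode: std is 0.0\n",
"print(summarise_episode_rewards([4.0, 2.0, 6.0]))  # several episodes: std reflects the spread"
]
},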
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
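"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"# Returns the mean and standard deviation of the episode rewards over n_eval_episodes\n",
"mean_reward, std_reward = evaluate_policy(eval_model, gym, n_eval_episodes=1)\n",
"print(f\"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}\")"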
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}