- Improve clarity of some code cells (narrowed output) - Reworded some questionably worded sections - Updated some of the util functionality that was using old action names - Updated a lot of old names into kebab-case - General tidy up and consistency changes.
132 lines
3.2 KiB
Plaintext
132 lines
3.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "plaintext"
|
|
}
|
|
},
|
|
"source": [
|
|
"# Training an Agent on UC7\n",
|
|
"\n",
|
|
"© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n",
|
|
"\n",
|
|
"This notebook is identical in content to the [training an SB3 agent](./Training-an-SB3-Agent.ipynb) notebook, except that it trains an agent on the [use case 7 scenario](./UC7-E2E-Demo.ipynb) rather than [use case 2](./Data-Manipulation-E2E-Demonstration.ipynb). By default, the `uc7_config.yaml` blue agent (`defender`) is set up to defend against Threat Actor Profile (TAP) 001, which can be explored in more detail [here](./UC7-TAP001-Kill-Chain-E2E.ipynb).\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### First, we import the initial packages and read in our configuration file."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!primaite setup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import yaml\n",
|
|
"from primaite.session.environment import PrimaiteGymEnv\n",
|
|
"from primaite import PRIMAITE_PATHS\n",
|
|
"from prettytable import PrettyTable\n",
|
|
"from deepdiff.diff import DeepDiff\n",
|
|
"scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/uc7_config.yaml\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gym = PrimaiteGymEnv(env_config=scenario_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from stable_baselines3 import PPO\n",
|
|
"\n",
|
|
"# EPISODE_LEN = 128\n",
|
|
"EPISODE_LEN = 128\n",
|
|
"NUM_EPISODES = 10\n",
|
|
"NO_STEPS = EPISODE_LEN * NUM_EPISODES\n",
|
|
"BATCH_SIZE = 32\n",
|
|
"LEARNING_RATE = 3e-4"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC7/\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.learn(total_timesteps=NO_STEPS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.save(\"PrimAITE-PPO-UC7-example-agent\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"eval_model = PPO(\"MlpPolicy\", gym)\n",
|
|
"eval_model = PPO.load(\"PrimAITE-PPO-UC7-example-agent\", gym)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
|
"\n",
|
|
"evaluate_policy(eval_model, gym, n_eval_episodes=1)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|