{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a single-agent system using RLlib\n",
    "\n",
    "© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK\n",
    "\n",
    "This notebook demonstrates how to use `PrimaiteRayEnv` to train a basic PPO agent on the [UC2 scenario](./Data-Manipulation-E2E-Demonstration.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!primaite setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import ray\n",
    "import yaml\n",
    "from ray.rllib.algorithms.ppo import PPOConfig\n",
    "\n",
    "from primaite.config.load import data_manipulation_config_path\n",
    "from primaite.session.ray_envs import PrimaiteRayEnv\n",
    "\n",
    "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
    "# to copy the files to your user data path.\n",
    "with open(data_manipulation_config_path(), \"r\") as f:\n",
    "    cfg = yaml.safe_load(f)\n",
    "\n",
    "# ignore_reinit_error lets this cell be re-run in the same kernel without Ray raising.\n",
    "ray.init(local_mode=True, ignore_reinit_error=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Ray algorithm and pass it our config\n",
    "\n",
    "The defender agent is set to flatten its observations, and the resulting PrimAITE config is passed to Ray as the environment's `env_config`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so `cfg` stays exactly as loaded from disk.\n",
    "env_config = copy.deepcopy(cfg)\n",
    "for agent in env_config[\"agents\"]:\n",
    "    if agent[\"ref\"] == \"defender\":\n",
    "        agent[\"agent_settings\"][\"flatten_obs\"] = True\n",
    "\n",
    "config = (\n",
    "    PPOConfig()\n",
    "    .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
    "    .env_runners(num_env_runners=0)\n",
    "    .training(train_batch_size=128)\n",
    "    .evaluation(evaluation_duration=1)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start the training\n",
    "\n",
    "Build the algorithm from the config and run one training iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algo = config.build()\n",
    "results = algo.train()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named eval_results (not `eval`) to avoid shadowing Python's builtin eval().\n",
    "eval_results = algo.evaluate()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up\n",
    "\n",
    "Stop the algorithm and shut down Ray to release the resources they hold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algo.stop()\n",
    "ray.shutdown()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}