PrimAITE/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb
Charlie Crane 85f03570f7 Merged PR 376: 2457 - Remove Hardcoding from Links
## Summary
This PR removes the hardcoded Link bandwidth and makes it configurable via the network YAML definitions.
Link bandwidth still defaults to 100 when it is not specified, so previously defined networks continue to work unchanged.
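
The fallback described above can be sketched as follows. This assumes each link entry in the network YAML is loaded as a dict and that the key is named `bandwidth`; both are assumptions, since the exact key and the other link fields are defined by PrimAITE's network configuration schema. `resolve_link_bandwidth` and `DEFAULT_LINK_BANDWIDTH` are hypothetical names.

```python
# Hypothetical sketch of the fallback; not PrimAITE's actual implementation.
# Assumes the YAML key for a link's bandwidth is "bandwidth".
DEFAULT_LINK_BANDWIDTH = 100  # default stated in this PR


def resolve_link_bandwidth(link_cfg: dict) -> float:
    """Return a link's configured bandwidth, or the default when none is given."""
    # Networks defined before this change omit the key, so they keep the old value.
    return link_cfg.get("bandwidth", DEFAULT_LINK_BANDWIDTH)


# Usage: one link entry with an explicit bandwidth, one relying on the default.
links = [{"bandwidth": 200}, {}]
print([resolve_link_bandwidth(link) for link in links])  # [200, 100]
```

Using `dict.get` with a default keeps older YAML files valid without a separate migration step.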

## Test process
All tests continue to pass.
`basic_network_config.yaml` now provides a non-default link bandwidth, which is verified in the unit tests.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [X] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

Related work items: #2457
2024-05-22 11:35:48 +00:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a single-agent system using RLlib\n",
"This notebook demonstrates how to use `PrimaiteRayEnv` to train a basic PPO agent with RLlib."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"\n",
"# If this config file is reported as missing, you may need to run `primaite setup` from the command line\n",
"# to copy the example configs into your user data path.\n",
"with open(data_manipulation_config_path(), \"r\") as f:\n",
"    cfg = yaml.safe_load(f)\n",
"\n",
"# Local mode runs all Ray tasks serially in this process, which keeps the example easy to debug.\n",
"ray.init(local_mode=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Configure the PPO algorithm and pass it the PrimAITE environment config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
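"# Flatten the defender agent's observation space into a single vector for RLlib.\n",
"for agent in cfg[\"agents\"]:\n",
"    if agent[\"ref\"] == \"defender\":\n",
"        agent[\"agent_settings\"][\"flatten_obs\"] = True\n",
"env_config = cfg\n",
"\n",
"# num_rollout_workers=0 collects samples in the local (driver) worker instead of separate processes.\n",
"config = (\n",
"    PPOConfig()\n",
"    .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
"    .rollouts(num_rollout_workers=0)\n",
"    .training(train_batch_size=128)\n",
")\n"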
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set the stopping criterion and start training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
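"# Stop after 5 * 128 = 640 environment steps, i.e. roughly five iterations at train_batch_size=128.\n",
"tune.Tuner(\n",
"    \"PPO\",\n",
"    run_config=air.RunConfig(\n",
"        stop={\"timesteps_total\": 5 * 128}\n",
"    ),\n",
"    param_space=config\n",
").fit()\n"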
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
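
The second code cell edits the loaded `cfg` in place. As a minimal sketch, the same change can be factored into a pure function that works on a copy, so re-running the cell never depends on earlier mutations. `with_flattened_defender_obs` is a hypothetical helper, and the dictionary layout (`agents`, `ref`, `agent_settings`) is assumed from the notebook cell rather than taken from PrimAITE's schema documentation.

```python
import copy
from typing import Any


def with_flattened_defender_obs(cfg: dict[str, Any], defender_ref: str = "defender") -> dict[str, Any]:
    """Return a deep copy of ``cfg`` with ``flatten_obs`` enabled for the defender agent.

    Hypothetical helper: the layout (``agents`` list, ``ref`` and ``agent_settings`` keys)
    is assumed from the notebook cell, not from PrimAITE's schema docs.
    """
    env_config = copy.deepcopy(cfg)  # leave the originally loaded YAML untouched
    for agent in env_config["agents"]:
        if agent["ref"] == defender_ref:
            # Create an empty agent_settings dict if the key is missing, then enable flattening.
            agent.setdefault("agent_settings", {})["flatten_obs"] = True
    return env_config


# Usage with a toy config (the "attacker" ref is made up for illustration).
toy_cfg = {
    "agents": [
        {"ref": "defender", "agent_settings": {}},
        {"ref": "attacker", "agent_settings": {}},
    ]
}
env_config = with_flattened_defender_obs(toy_cfg)
print(env_config["agents"][0]["agent_settings"])  # {'flatten_obs': True}
print(toy_cfg["agents"][0]["agent_settings"])  # {}
```

The returned dict can then be passed as `env_config` to `PPOConfig().environment(...)`, as the notebook does.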