Check that ray single agent training works

This commit is contained in:
Marek Wolan
2023-11-22 13:26:29 +00:00
parent b81dd26b71
commit 9070fb44d4
2 changed files with 136 additions and 2 deletions

View File

@@ -68,8 +68,13 @@ class PrimaiteGymEnv(gymnasium.Env):
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment."""
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`,
whose value is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
"""
self.env = PrimaiteGymEnv(game=env_config["game"])
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space

View File

@@ -0,0 +1,129 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.game.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray.rllib.algorithms import ppo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"game\":game}\n",
"config = {\n",
" \"env\" : PrimaiteRayEnv,\n",
" \"env_config\" : env_config,\n",
" \"disable_env_checking\": True,\n",
" \"num_rollout_workers\": 0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo = ppo.PPO(config=config)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(5):\n",
" result = algo.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo.save(\"temp/deleteme\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}