Add multi agent session test

This commit is contained in:
Marek Wolan
2023-11-23 02:51:31 +00:00
parent 8a2279c6cb
commit f1f516c51a
3 changed files with 1223 additions and 1 deletions

View File

@@ -2,13 +2,15 @@ from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.game.policy.policy import PolicyABC
from primaite.session.environment import PrimaiteRayEnv
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
@@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
    """Build a policy instance from the training options and the owning session.

    :param config: Training options; ``rl_algorithm`` and ``seed`` are read from it.
    :param session: Session that is passed straight through to the constructor.
    :return: The newly constructed policy.
    """
    init_kwargs = {
        "session": session,
        "algorithm": config.rl_algorithm,
        "seed": config.seed,
    }
    return cls(**init_kwargs)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
    """Multi agent RL policy using Ray RLLib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
        """Initialise multi agent policy wrapper.

        :param session: Session whose ``game`` is handed to the RLlib environment and whose
            ``game.rl_agents`` define the set of policies (one per agent name).
        :param algorithm: Name of the RL algorithm. Only a PPO config is built here, so the
            value is currently not consulted.
        :param seed: Optional seed forwarded to RLlib so that runs can be reproduced.
        """
        super().__init__(session=session)
        self.config = (
            PPOConfig()
            .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
            .rollouts(num_rollout_workers=0)
            .multi_agent(
                # One policy per RL agent, keyed by the agent's name.
                policies={agent.agent_name for agent in session.game.rl_agents},
                # Each agent id maps onto the policy of the same name.
                policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
            )
            .training(train_batch_size=128)
            # Previously the seed was accepted but silently dropped.
            .debugging(seed=seed)
        )

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """Train the agents.

        :param n_episodes: Number of episodes to train for.
        :param timesteps_per_episode: Number of environment steps in each episode.
        """
        # The product is a budget of environment steps, so stop on the sampled-timestep
        # counter. Stopping on ``training_iteration`` would run that many whole training
        # iterations, each of which samples ``train_batch_size`` steps.
        total_timesteps = n_episodes * timesteps_per_episode
        tune.Tuner(
            "PPO",
            run_config=air.RunConfig(
                stop={"timesteps_total": total_timesteps},
                checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
            ),
            param_space=self.config,
        ).fit()

    def load(self, model_path: Path) -> None:
        """Load policy parameters from a file."""
        # NOTE(review): not implemented yet. Returning ``NotImplemented`` contradicts the
        # ``-> None`` annotation; kept as-is so existing callers do not start raising.
        return NotImplemented

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate trained policy."""
        # NOTE(review): not implemented yet; see the note in ``load`` about the return value.
        return NotImplemented

    def save(self, save_path: Path) -> None:
        """Save policy parameters to a file."""
        # NOTE(review): not implemented yet; see the note in ``load`` about the return value.
        return NotImplemented

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
        """Create policy from config.

        :param config: Training options; ``rl_algorithm`` and ``seed`` are read from it.
        :param session: Session the policy is attached to.
        :return: The newly constructed multi agent policy.
        """
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)