Add multi agent session test
This commit is contained in:
@@ -2,13 +2,15 @@ from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
@@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
|
||||
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
    """Multi agent RL policy using Ray RLLib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
        """
        Initialise multi agent policy wrapper.

        :param session: Session whose game is passed to the environment and whose RL agents each receive a policy.
        :param algorithm: Requested algorithm name. Only a PPO config is built here, so the value is not consulted.
        :param seed: Optional seed forwarded to the RLLib config. ``None`` leaves the RLLib default in place.
        """
        super().__init__(session=session)

        # One policy id per RL agent, named after the agent.
        policy_ids = {agent.agent_name for agent in session.game.rl_agents}

        self.config = (
            PPOConfig()
            .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
            .rollouts(num_rollout_workers=0)
            .multi_agent(
                policies=policy_ids,
                # Identity mapping: assumes PrimaiteRayMARLEnv keys its agents by agent_name - TODO confirm.
                policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
            )
            .training(train_batch_size=128)
            # Forward the caller's seed; previously it was accepted and silently dropped.
            .debugging(seed=seed)
        )

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """
        Train the agents.

        :param n_episodes: Number of episodes to train for.
        :param timesteps_per_episode: Number of environment steps in each episode.
        """
        # The budget is expressed in environment steps, so stop on the step counter. Stopping on
        # ``training_iteration`` would run that many full train batches instead.
        total_timesteps = n_episodes * timesteps_per_episode
        tune.Tuner(
            "PPO",
            run_config=air.RunConfig(
                stop={"timesteps_total": total_timesteps},
                checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
            ),
            param_space=self.config,
        ).fit()

    # NOTE(review): the three methods below return the ``NotImplemented`` sentinel instead of raising
    # ``NotImplementedError``, so callers get no error. Kept as-is so existing call sites keep running.

    def load(self, model_path: Path) -> None:
        """Load policy parameters from a file. Not implemented yet."""
        return NotImplemented

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate trained policy. Not implemented yet."""
        return NotImplemented

    def save(self, save_path: Path) -> None:
        """Save policy parameters to a file. Not implemented yet."""
        return NotImplemented

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
        """
        Create policy from config.

        :param config: Training options supplying ``rl_algorithm`` and ``seed``.
        :param session: Session the new policy is attached to.
        :return: A policy built from the given options.
        """
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
Reference in New Issue
Block a user