Add multi agent session test

This commit is contained in:
Marek Wolan
2023-11-23 02:51:31 +00:00
parent 8a2279c6cb
commit f1f516c51a
3 changed files with 1223 additions and 1 deletions

View File

@@ -2,13 +2,15 @@ from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.game.policy.policy import PolicyABC
from primaite.session.environment import PrimaiteRayEnv
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
@@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
    """Multi agent RL policy using Ray RLLib."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
        """
        Initialise multi agent policy wrapper.

        Assembles a PPO configuration for the multi agent environment and stores it on ``self.config``.

        :param session: Session whose ``game`` is handed to the environment and whose RL agents define the policies.
        :param algorithm: Accepted by the signature but not read here; a ``PPOConfig`` is always built.
        :param seed: Accepted by the signature but not read here.
        """
        super().__init__(session=session)

        ppo_config = PPOConfig()
        ppo_config = ppo_config.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
        ppo_config = ppo_config.rollouts(num_rollout_workers=0)

        # One policy per RL agent, named after the agent. The mapping function returns the agent id unchanged,
        # so this presumably relies on the environment using agent names as agent ids -- TODO confirm.
        policy_ids = {agent.agent_name for agent in session.game.rl_agents}
        ppo_config = ppo_config.multi_agent(
            policies=policy_ids,
            policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
        )

        self.config = ppo_config.training(train_batch_size=128)

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """
        Train the agents by fitting a Ray Tune ``Tuner`` on the "PPO" trainable with the stored config.

        :param n_episodes: Number of episodes; multiplied with ``timesteps_per_episode`` for the stop criterion.
        :param timesteps_per_episode: Steps per episode; multiplied with ``n_episodes`` for the stop criterion.
        """
        # NOTE(review): the product is used as a ``training_iteration`` limit, not as a timestep count --
        # confirm that this is the intended stopping condition.
        iteration_limit = n_episodes * timesteps_per_episode
        checkpointing = air.CheckpointConfig(checkpoint_frequency=10)
        run_config = air.RunConfig(
            stop={"training_iteration": iteration_limit},
            checkpoint_config=checkpointing,
        )
        tuner = tune.Tuner("PPO", run_config=run_config, param_space=self.config)
        tuner.fit()

    def load(self, model_path: Path) -> None:
        """Load policy parameters from a file (placeholder: no loading is done, ``NotImplemented`` is returned)."""
        return NotImplemented

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """Evaluate trained policy (placeholder: no evaluation is done, ``NotImplemented`` is returned)."""
        return NotImplemented

    def save(self, save_path: Path) -> None:
        """Save policy parameters to a file (placeholder: nothing is written, ``NotImplemented`` is returned)."""
        return NotImplemented

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
        """
        Create policy from config.

        :param config: Training options; ``rl_algorithm`` and ``seed`` are forwarded to the constructor.
        :param session: Session the policy is attached to.
        :return: A new policy instance.
        """
        algorithm = config.rl_algorithm
        seed = config.seed
        return cls(session=session, algorithm=algorithm, seed=seed)

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,7 @@ CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
# Repository-relative paths to YAML session configs under tests/assets/configs.
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
# Config that the bad-configuration test expects to be rejected with a pydantic.ValidationError.
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
# Config handed to the temp_primaite_session fixture by the multi agent session test.
MULTI_AGENT_PATH = "tests/assets/configs/multi_agent_session.yaml"
class TestPrimaiteSession:
@@ -63,6 +64,12 @@ class TestPrimaiteSession:
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
def test_multi_agent_session(self, temp_primaite_session):
    """Start a session built from the multi agent config and expect it to finish without raising."""
    with temp_primaite_session as multi_agent_session:
        # Smoke test only: there are no assertions, so the test passes when start_session() returns normally.
        multi_agent_session.start_session()
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)