Add multi agent session test
This commit is contained in:
@@ -2,13 +2,15 @@ from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
@@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
|
||||
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
|
||||
"""Mutli agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
|
||||
"""Initialise multi agent policy wrapper."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in session.game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": n_episodes * timesteps_per_episode},
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
|
||||
),
|
||||
param_space=self.config,
|
||||
).fit()
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy paramters from a file."""
|
||||
return NotImplemented
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate trained policy."""
|
||||
return NotImplemented
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save policy parameters to a file."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
|
||||
"""Create policy from config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
1166
tests/assets/configs/multi_agent_session.yaml
Normal file
1166
tests/assets/configs/multi_agent_session.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = "tests/assets/configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
class TestPrimaiteSession:
|
||||
@@ -63,6 +64,12 @@ class TestPrimaiteSession:
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was loaded and that the eval-only session ran
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
|
||||
def test_multi_agent_session(self, temp_primaite_session):
|
||||
"""Check that we can run a training session with a multi agent system."""
|
||||
with temp_primaite_session as session:
|
||||
session.start_session()
|
||||
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
|
||||
Reference in New Issue
Block a user