Update end to end tests after session changes

This commit is contained in:
Marek Wolan
2023-11-23 01:40:27 +00:00
parent 14ae8be5e2
commit 8a2279c6cb
10 changed files with 2099 additions and 2001 deletions

View File

@@ -0,0 +1,43 @@
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayMARLEnv
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()

View File

@@ -0,0 +1,38 @@
import tempfile
from pathlib import Path
import ray
import yaml
from ray.rllib.algorithms import ppo
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayEnv
def test_rllib_single_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = {
"env": PrimaiteRayEnv,
"env_config": env_config,
"disable_env_checking": True,
"num_rollout_workers": 0,
}
algo = ppo.PPO(config=config)
for i in range(5):
result = algo.train()
save_file = Path(tempfile.gettempdir()) / "ray/"
algo.save(save_file)
assert save_file.exists()

View File

@@ -0,0 +1,27 @@
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
import tempfile
from pathlib import Path
import yaml
from stable_baselines3 import PPO
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
gym = PrimaiteGymEnv(game=game)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path)
assert (save_path).exists()

View File

@@ -18,15 +18,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):