Update end to end tests after session changes
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
import ray
|
||||
import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
|
||||
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": 128},
|
||||
checkpoint_config=air.CheckpointConfig(
|
||||
checkpoint_frequency=10,
|
||||
),
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
@@ -0,0 +1,38 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import ray
|
||||
import yaml
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
|
||||
|
||||
def test_rllib_single_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": env_config,
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
algo = ppo.PPO(config=config)
|
||||
|
||||
for i in range(5):
|
||||
result = algo.train()
|
||||
|
||||
save_file = Path(tempfile.gettempdir()) / "ray/"
|
||||
algo.save(save_file)
|
||||
assert save_file.exists()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
def test_sb3_compatibility():
|
||||
"""Test that the Gymnasium environment can be used with an SB3 agent."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
gym = PrimaiteGymEnv(game=game)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
||||
model.save(save_path)
|
||||
|
||||
assert (save_path).exists()
|
||||
@@ -18,15 +18,15 @@ class TestPrimaiteSession:
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.simulation
|
||||
assert len(session.agents) == 3
|
||||
assert len(session.rl_agents) == 1
|
||||
assert session.game.simulation
|
||||
assert len(session.game.agents) == 3
|
||||
assert len(session.game.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.simulation.network
|
||||
assert len(session.simulation.network.nodes) == 10
|
||||
assert session.game.simulation.network
|
||||
assert len(session.game.simulation.network.nodes) == 10
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
|
||||
Reference in New Issue
Block a user