Update end to end tests after session changes
This commit is contained in:
@@ -96,17 +96,6 @@
|
||||
"source": [
|
||||
"algo.save(\"temp/deleteme\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"from primaite.main import run\n",
|
||||
"run(example_config_path())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -12,6 +12,7 @@ import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
@@ -74,7 +75,7 @@ def file_system() -> FileSystem:
|
||||
|
||||
|
||||
# PrimAITE v2 stuff
|
||||
class TempPrimaiteSession(PrimaiteGame):
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import ray
|
||||
import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
|
||||
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": 128},
|
||||
checkpoint_config=air.CheckpointConfig(
|
||||
checkpoint_frequency=10,
|
||||
),
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
@@ -0,0 +1,38 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import ray
|
||||
import yaml
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
|
||||
|
||||
def test_rllib_single_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": env_config,
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
algo = ppo.PPO(config=config)
|
||||
|
||||
for i in range(5):
|
||||
result = algo.train()
|
||||
|
||||
save_file = Path(tempfile.gettempdir()) / "ray/"
|
||||
algo.save(save_file)
|
||||
assert save_file.exists()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
def test_sb3_compatibility():
|
||||
"""Test that the Gymnasium environment can be used with an SB3 agent."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
gym = PrimaiteGymEnv(game=game)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
||||
model.save(save_path)
|
||||
|
||||
assert (save_path).exists()
|
||||
@@ -18,15 +18,15 @@ class TestPrimaiteSession:
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.simulation
|
||||
assert len(session.agents) == 3
|
||||
assert len(session.rl_agents) == 1
|
||||
assert session.game.simulation
|
||||
assert len(session.game.agents) == 3
|
||||
assert len(session.game.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.simulation.network
|
||||
assert len(session.simulation.network.nodes) == 10
|
||||
assert session.game.simulation.network
|
||||
assert len(session.game.simulation.network.nodes) == 10
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
|
||||
Reference in New Issue
Block a user