Update end to end tests after session changes

This commit is contained in:
Marek Wolan
2023-11-23 01:40:27 +00:00
parent 14ae8be5e2
commit 8a2279c6cb
10 changed files with 2099 additions and 2001 deletions

View File

@@ -96,17 +96,6 @@
"source": [
"algo.save(\"temp/deleteme\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.config.load import example_config_path\n",
"from primaite.main import run\n",
"run(example_config_path())"
]
}
],
"metadata": {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -12,6 +12,7 @@ import yaml
from primaite import getLogger
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession
@@ -74,7 +75,7 @@ def file_system() -> FileSystem:
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteGame):
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.

View File

@@ -0,0 +1,43 @@
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayMARLEnv
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()

View File

@@ -0,0 +1,38 @@
import tempfile
from pathlib import Path
import ray
import yaml
from ray.rllib.algorithms import ppo
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayEnv
def test_rllib_single_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = {
"env": PrimaiteRayEnv,
"env_config": env_config,
"disable_env_checking": True,
"num_rollout_workers": 0,
}
algo = ppo.PPO(config=config)
for i in range(5):
result = algo.train()
save_file = Path(tempfile.gettempdir()) / "ray/"
algo.save(save_file)
assert save_file.exists()

View File

@@ -0,0 +1,27 @@
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
import tempfile
from pathlib import Path
import yaml
from stable_baselines3 import PPO
from primaite.config.load import example_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
gym = PrimaiteGymEnv(game=game)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path)
assert (save_path).exists()

View File

@@ -18,15 +18,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):