Fix test problems and slowness
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
from ray import init as rayinit
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
ray.init(local_mode=True)
|
||||
rayinit(local_mode=True)
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
from ray import air, init, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
||||
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
|
||||
@@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch):
|
||||
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
|
||||
|
||||
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
|
||||
model.learn(512)
|
||||
model.learn(256)
|
||||
|
||||
assert len(action_num_history) == len(mask_history) > 0
|
||||
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import ray
|
||||
import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
@@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
@@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility():
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": 128},
|
||||
checkpoint_config=air.CheckpointConfig(
|
||||
checkpoint_frequency=10,
|
||||
),
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
algo = config.build()
|
||||
algo.train()
|
||||
|
||||
@@ -3,7 +3,6 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_sb3_compatibility():
|
||||
gym = PrimaiteGymEnv(env_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
model.learn(total_timesteps=256)
|
||||
|
||||
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
||||
model.save(save_path)
|
||||
|
||||
Reference in New Issue
Block a user