Merged PR 459: fix test problems and slowness
Related work items: #2747
This commit is contained in:
@@ -179,18 +179,18 @@ class RequestManager(BaseModel):
|
||||
requests = []
|
||||
for req_name, req in self.request_types.items():
|
||||
if isinstance(req.func, RequestManager):
|
||||
sub_requests = req.func.get_request_types_recursively() # recurse
|
||||
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf
|
||||
sub_requests = req.func.get_request_types_recursively()
|
||||
sub_requests = [[req_name] + a for a in sub_requests]
|
||||
requests.extend(sub_requests)
|
||||
else: # leaf node found
|
||||
requests.append(req_name)
|
||||
else:
|
||||
requests.append([req_name])
|
||||
return requests
|
||||
|
||||
def show(self) -> None:
|
||||
"""Display all currently available requests and whether they are valid."""
|
||||
table = PrettyTable(["request"])
|
||||
"""Display all currently available requests."""
|
||||
table = PrettyTable(["requests"])
|
||||
table.align = "l"
|
||||
table.add_rows(self.get_request_types_recursively())
|
||||
table.add_rows([[x] for x in self.get_request_types_recursively()])
|
||||
print(table)
|
||||
|
||||
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
from ray import init as rayinit
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
ray.init(local_mode=True)
|
||||
rayinit(local_mode=True)
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
from ray import air, init, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
||||
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
|
||||
@@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch):
|
||||
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
|
||||
|
||||
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
|
||||
model.learn(512)
|
||||
model.learn(256)
|
||||
|
||||
assert len(action_num_history) == len(mask_history) > 0
|
||||
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import ray
|
||||
import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
@@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
@@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility():
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": 128},
|
||||
checkpoint_config=air.CheckpointConfig(
|
||||
checkpoint_frequency=10,
|
||||
),
|
||||
),
|
||||
param_space=config,
|
||||
).fit()
|
||||
algo = config.build()
|
||||
algo.train()
|
||||
|
||||
@@ -3,7 +3,6 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_sb3_compatibility():
|
||||
gym = PrimaiteGymEnv(env_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
model.learn(total_timesteps=256)
|
||||
|
||||
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
||||
model.save(save_path)
|
||||
|
||||
Reference in New Issue
Block a user