fix test problems and slowness

This commit is contained in:
Marek Wolan
2024-07-12 11:23:41 +01:00
parent e759ae5990
commit 199cd0d9df
6 changed files with 13 additions and 28 deletions

View File

@@ -179,18 +179,18 @@ class RequestManager(BaseModel):
requests = [] requests = []
for req_name, req in self.request_types.items(): for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager): if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively() # recurse sub_requests = req.func.get_request_types_recursively()
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf sub_requests = [[req_name] + a for a in sub_requests]
requests.extend(sub_requests) requests.extend(sub_requests)
else: # leaf node found else:
requests.append(req_name) requests.append([req_name])
return requests return requests
def show(self) -> None: def show(self) -> None:
"""Display all currently available requests and whether they are valid.""" """Display all currently available requests."""
table = PrettyTable(["request"]) table = PrettyTable(["requests"])
table.align = "l" table.align = "l"
table.add_rows(self.get_request_types_recursively()) table.add_rows([[x] for x in self.get_request_types_recursively()])
print(table) print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool: def check_valid(self, request: RequestFormat, context: Dict) -> bool:

View File

@@ -2,8 +2,8 @@
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
import pytest import pytest
import ray
import yaml import yaml
from ray import init as rayinit
from primaite import getLogger, PRIMAITE_PATHS from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager from primaite.game.agent.actions import ActionManager
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT from tests import TEST_ASSETS_ROOT
ray.init(local_mode=True) rayinit(local_mode=True)
ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1

View File

@@ -1,9 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import importlib
from typing import Dict from typing import Dict
import yaml import yaml
from ray import air, init, tune
from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
@@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch):
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
model.learn(512) model.learn(256)
assert len(action_num_history) == len(mask_history) > 0 assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled # Make sure the masks had at least some False entries, if it was all True then the mask was disabled

View File

@@ -1,7 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import ray
import yaml import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo import PPOConfig
from primaite.session.ray_envs import PrimaiteRayMARLEnv from primaite.session.ray_envs import PrimaiteRayMARLEnv
@@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_rllib_multi_agent_compatibility(): def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(MULTI_AGENT_PATH, "r") as f: with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
@@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility():
) )
.training(train_batch_size=128) .training(train_batch_size=128)
) )
algo = config.build()
tune.Tuner( algo.train()
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()

View File

@@ -3,7 +3,6 @@ import tempfile
from pathlib import Path from pathlib import Path
import pytest import pytest
import ray
import yaml import yaml
from ray.rllib.algorithms import ppo from ray.rllib.algorithms import ppo

View File

@@ -20,7 +20,7 @@ def test_sb3_compatibility():
gym = PrimaiteGymEnv(env_config=cfg) gym = PrimaiteGymEnv(env_config=cfg)
model = PPO("MlpPolicy", gym) model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000) model.learn(total_timesteps=256)
save_path = Path(tempfile.gettempdir()) / "model.zip" save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path) model.save(save_path)