fix test problems and slowness
This commit is contained in:
@@ -179,18 +179,18 @@ class RequestManager(BaseModel):
|
|||||||
requests = []
|
requests = []
|
||||||
for req_name, req in self.request_types.items():
|
for req_name, req in self.request_types.items():
|
||||||
if isinstance(req.func, RequestManager):
|
if isinstance(req.func, RequestManager):
|
||||||
sub_requests = req.func.get_request_types_recursively() # recurse
|
sub_requests = req.func.get_request_types_recursively()
|
||||||
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf
|
sub_requests = [[req_name] + a for a in sub_requests]
|
||||||
requests.extend(sub_requests)
|
requests.extend(sub_requests)
|
||||||
else: # leaf node found
|
else:
|
||||||
requests.append(req_name)
|
requests.append([req_name])
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
def show(self) -> None:
|
def show(self) -> None:
|
||||||
"""Display all currently available requests and whether they are valid."""
|
"""Display all currently available requests."""
|
||||||
table = PrettyTable(["request"])
|
table = PrettyTable(["requests"])
|
||||||
table.align = "l"
|
table.align = "l"
|
||||||
table.add_rows(self.get_request_types_recursively())
|
table.add_rows([[x] for x in self.get_request_types_recursively()])
|
||||||
print(table)
|
print(table)
|
||||||
|
|
||||||
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
|
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import ray
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from ray import init as rayinit
|
||||||
|
|
||||||
from primaite import getLogger, PRIMAITE_PATHS
|
from primaite import getLogger, PRIMAITE_PATHS
|
||||||
from primaite.game.agent.actions import ActionManager
|
from primaite.game.agent.actions import ActionManager
|
||||||
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
|
|||||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||||
from tests import TEST_ASSETS_ROOT
|
from tests import TEST_ASSETS_ROOT
|
||||||
|
|
||||||
ray.init(local_mode=True)
|
rayinit(local_mode=True)
|
||||||
ACTION_SPACE_NODE_VALUES = 1
|
ACTION_SPACE_NODE_VALUES = 1
|
||||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||||
import importlib
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from ray import air, init, tune
|
|
||||||
from ray.rllib.algorithms.ppo import PPOConfig
|
from ray.rllib.algorithms.ppo import PPOConfig
|
||||||
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
||||||
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
|
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
|
||||||
@@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch):
|
|||||||
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
|
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
|
||||||
|
|
||||||
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
|
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
|
||||||
model.learn(512)
|
model.learn(256)
|
||||||
|
|
||||||
assert len(action_num_history) == len(mask_history) > 0
|
assert len(action_num_history) == len(mask_history) > 0
|
||||||
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||||
import ray
|
|
||||||
import yaml
|
import yaml
|
||||||
from ray import air, tune
|
|
||||||
from ray.rllib.algorithms.ppo import PPOConfig
|
from ray.rllib.algorithms.ppo import PPOConfig
|
||||||
|
|
||||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||||
@@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
|||||||
|
|
||||||
def test_rllib_multi_agent_compatibility():
|
def test_rllib_multi_agent_compatibility():
|
||||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||||
|
|
||||||
with open(MULTI_AGENT_PATH, "r") as f:
|
with open(MULTI_AGENT_PATH, "r") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
@@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility():
|
|||||||
)
|
)
|
||||||
.training(train_batch_size=128)
|
.training(train_batch_size=128)
|
||||||
)
|
)
|
||||||
|
algo = config.build()
|
||||||
tune.Tuner(
|
algo.train()
|
||||||
"PPO",
|
|
||||||
run_config=air.RunConfig(
|
|
||||||
stop={"training_iteration": 128},
|
|
||||||
checkpoint_config=air.CheckpointConfig(
|
|
||||||
checkpoint_frequency=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
param_space=config,
|
|
||||||
).fit()
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import ray
|
|
||||||
import yaml
|
import yaml
|
||||||
from ray.rllib.algorithms import ppo
|
from ray.rllib.algorithms import ppo
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ def test_sb3_compatibility():
|
|||||||
gym = PrimaiteGymEnv(env_config=cfg)
|
gym = PrimaiteGymEnv(env_config=cfg)
|
||||||
model = PPO("MlpPolicy", gym)
|
model = PPO("MlpPolicy", gym)
|
||||||
|
|
||||||
model.learn(total_timesteps=1000)
|
model.learn(total_timesteps=256)
|
||||||
|
|
||||||
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
save_path = Path(tempfile.gettempdir()) / "model.zip"
|
||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user