Merged PR 459: fix test problems and slowness

Related work items: #2747
This commit is contained in:
Marek Wolan
2024-07-12 11:49:20 +00:00
6 changed files with 13 additions and 28 deletions

View File

@@ -179,18 +179,18 @@ class RequestManager(BaseModel):
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively() # recurse
sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
requests.extend(sub_requests)
else: # leaf node found
requests.append(req_name)
else:
requests.append([req_name])
return requests
def show(self) -> None:
"""Display all currently available requests and whether they are valid."""
table = PrettyTable(["request"])
"""Display all currently available requests."""
table = PrettyTable(["requests"])
table.align = "l"
table.add_rows(self.get_request_types_recursively())
table.add_rows([[x] for x in self.get_request_types_recursively()])
print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool:

View File

@@ -2,8 +2,8 @@
from typing import Any, Dict, Tuple
import pytest
import ray
import yaml
from ray import init as rayinit
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
@@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
ray.init(local_mode=True)
rayinit(local_mode=True)
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1

View File

@@ -1,9 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import importlib
from typing import Dict
import yaml
from ray import air, init, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
@@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch):
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
model.learn(512)
model.learn(256)
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled

View File

@@ -1,7 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.session.ray_envs import PrimaiteRayMARLEnv
@@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f)
@@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility():
)
.training(train_batch_size=128)
)
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()
algo = config.build()
algo.train()

View File

@@ -3,7 +3,6 @@ import tempfile
from pathlib import Path
import pytest
import ray
import yaml
from ray.rllib.algorithms import ppo

View File

@@ -20,7 +20,7 @@ def test_sb3_compatibility():
gym = PrimaiteGymEnv(env_config=cfg)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
model.learn(total_timesteps=256)
save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path)