diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index a5e39cc8..848570fe 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -179,18 +179,18 @@ class RequestManager(BaseModel): requests = [] for req_name, req in self.request_types.items(): if isinstance(req.func, RequestManager): - sub_requests = req.func.get_request_types_recursively() # recurse - sub_requests = [([req_name] + a) for a in sub_requests] # prepend parent request to leaf + sub_requests = req.func.get_request_types_recursively() + sub_requests = [[req_name] + a for a in sub_requests] requests.extend(sub_requests) - else: # leaf node found - requests.append(req_name) + else: + requests.append([req_name]) return requests def show(self) -> None: - """Display all currently available requests and whether they are valid.""" - table = PrettyTable(["request"]) + """Display all currently available requests.""" + table = PrettyTable(["requests"]) table.align = "l" - table.add_rows(self.get_request_types_recursively()) + table.add_rows([[x] for x in self.get_request_types_recursively()]) print(table) def check_valid(self, request: RequestFormat, context: Dict) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index b8b50182..54519e2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,8 @@ from typing import Any, Dict, Tuple import pytest -import ray import yaml +from ray import init as rayinit from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager @@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT -ray.init(local_mode=True) +rayinit(local_mode=True) ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index a299b913..745e280b 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -1,9 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import importlib from typing import Dict import yaml -from ray import air, init, tune from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec @@ -43,7 +41,7 @@ def test_sb3_action_masking(monkeypatch): monkeypatch.setattr(env, "step", lambda action: cache_step(env, action)) model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32) - model.learn(512) + model.learn(256) assert len(action_num_history) == len(mask_history) > 0 # Make sure the masks had at least some False entries, if it was all True then the mask was disabled diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index e015c33c..26e690d0 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -1,7 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -import ray import yaml -from ray import air, tune from ray.rllib.algorithms.ppo import PPOConfig from primaite.session.ray_envs import PrimaiteRayMARLEnv @@ -12,7 +10,6 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" def test_rllib_multi_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" - with open(MULTI_AGENT_PATH, "r") as f: cfg = yaml.safe_load(f) @@ -26,14 +23,5 @@ def test_rllib_multi_agent_compatibility(): ) .training(train_batch_size=128) ) - - tune.Tuner( - "PPO", - run_config=air.RunConfig( - stop={"training_iteration": 128}, - checkpoint_config=air.CheckpointConfig( - checkpoint_frequency=10, - ), - ), - param_space=config, - ).fit() + algo = config.build() + algo.train() diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index a02a078c..265257e4 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -3,7 +3,6 @@ import tempfile from pathlib import Path import pytest -import ray import yaml from ray.rllib.algorithms import ppo diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 27fb134b..a07d5d2e 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -20,7 +20,7 @@ def test_sb3_compatibility(): gym = PrimaiteGymEnv(env_config=cfg) model = PPO("MlpPolicy", gym) - model.learn(total_timesteps=1000) + model.learn(total_timesteps=256) save_path = Path(tempfile.gettempdir()) / "model.zip" model.save(save_path)