2025-01-02 15:05:06 +00:00
|
|
|
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
2023-11-23 01:40:27 +00:00
|
|
|
"""Test that we can create a primaite environment and train sb3 agent with no crash."""
|
|
|
|
|
import tempfile
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
2023-11-29 01:28:40 +00:00
|
|
|
import pytest
|
2023-11-23 01:40:27 +00:00
|
|
|
import yaml
|
|
|
|
|
from stable_baselines3 import PPO
|
|
|
|
|
|
2024-03-04 19:43:51 +00:00
|
|
|
from primaite.config.load import data_manipulation_config_path
|
2023-11-23 01:40:27 +00:00
|
|
|
from primaite.game.game import PrimaiteGame
|
|
|
|
|
from primaite.session.environment import PrimaiteGymEnv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sb3_compatibility():
    """Test that the Gymnasium environment can be used with an SB3 agent.

    Builds a ``PrimaiteGymEnv`` from the config file returned by
    ``data_manipulation_config_path()``, trains a PPO agent for a short run,
    and checks that the trained model can be saved to disk.
    """
    with open(data_manipulation_config_path(), "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    env = PrimaiteGymEnv(env_config=cfg)
    model = PPO("MlpPolicy", env)
    # A short run is enough to exercise the env/agent integration without
    # noticeably slowing down the test suite.
    model.learn(total_timesteps=256)

    # Save into a private temporary directory instead of a fixed file name in
    # the shared system temp dir: concurrent test runs cannot collide on the
    # same path, and the directory (including the saved model) is removed on
    # exit from the ``with`` block even if the assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = Path(tmp_dir) / "model.zip"
        model.save(save_path)
        assert save_path.exists()
|