Files
PrimAITE/tests/e2e_integration_tests/test_session_repeatability.py

58 lines
1.8 KiB
Python

"""
Seed tests.
These tests will train an agent.
This agent is then loaded and evaluated twice,
the 2 evaluation runs should be the same.
This proves that the seed works.
"""
import time
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.primaite_session import PrimaiteSession
from tests import TEST_CONFIG_ROOT
def test_seeded_sessions():
    """
    Train a seeded PPO agent, then evaluate it in two separate sessions.

    A training session is run first and closed. Two further sessions are then
    created from the same training config, each pointed at the agent file saved
    by the training session, and both are evaluated.

    NOTE: no assertion is currently made on the evaluation outputs; the
    transaction-file comparison is kept at the end of the test as a disabled
    reference.
    """
    training_config = TEST_CONFIG_ROOT / "e2e/ppo_seeded_training_config.yaml"

    def _new_session():
        # Every session in this test shares the same training config and the
        # "dos very basic" lay down config.
        return PrimaiteSession(training_config, dos_very_basic_config_path())

    # --- training phase ---
    trainer = _new_session()
    trainer.setup()
    trainer.learn()
    trainer.close()

    # --- locate the agent saved by the training session ---
    cfg = trainer._training_config
    agent_stem = f"{cfg.agent_framework}_{cfg.agent_identifier}_{trainer.timestamp_str}"
    agent_path = trainer.session_path / f"{agent_stem}.zip"

    # --- first evaluation session, loading the trained agent ---
    first_eval = _new_session()
    first_eval._training_config.agent_load_file = agent_path
    first_eval.setup()

    # NOTE(review): presumably this pause keeps the two sessions' timestamps
    # (and hence output paths) distinct -- confirm against PrimaiteSession.
    time.sleep(1)

    # --- second evaluation session, loading the same trained agent ---
    second_eval = _new_session()
    second_eval._training_config.agent_load_file = agent_path
    second_eval.setup()

    # --- run the evaluations one after the other, closing each when done ---
    for evaluator in (first_eval, second_eval):
        evaluator.evaluate()
        evaluator.close()

    # TODO: the output comparison is disabled. `compare_transaction_file` is not
    # imported in the visible part of this file. Intended check:
    # assert compare_transaction_file(
    #     first_eval.evaluation_path / f"all_transactions_{first_eval.timestamp_str}.csv",
    #     second_eval.evaluation_path / f"all_transactions_{second_eval.timestamp_str}.csv"
    # ) is True