58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
"""
|
|
Seed tests.
|
|
|
|
These tests will train an agent.
|
|
This agent is then loaded and evaluated twice,
|
|
the 2 evaluation wuns should be the same.
|
|
|
|
This proves that the seed works.
|
|
"""
|
|
import time
|
|
|
|
from primaite.config.lay_down_config import dos_very_basic_config_path
|
|
from primaite.primaite_session import PrimaiteSession
|
|
from tests import TEST_CONFIG_ROOT
|
|
|
|
|
|
def test_seeded_sessions():
|
|
"""Test to see if seed works in multiple sessions."""
|
|
# ppo training session
|
|
ppo_train = PrimaiteSession(TEST_CONFIG_ROOT / "e2e/ppo_seeded_training_config.yaml", dos_very_basic_config_path())
|
|
# train agent
|
|
ppo_train.setup()
|
|
ppo_train.learn()
|
|
ppo_train.close()
|
|
|
|
# agent path to use for evaluation
|
|
path_prefix = f"{ppo_train._training_config.agent_framework}_{ppo_train._training_config.agent_identifier}"
|
|
agent_path = ppo_train.session_path / f"{path_prefix}_{ppo_train.timestamp_str}.zip"
|
|
|
|
ppo_session_1 = PrimaiteSession(
|
|
TEST_CONFIG_ROOT / "e2e/ppo_seeded_training_config.yaml", dos_very_basic_config_path()
|
|
)
|
|
|
|
# load trained agent
|
|
ppo_session_1._training_config.agent_load_file = agent_path
|
|
ppo_session_1.setup()
|
|
time.sleep(1)
|
|
|
|
ppo_session_2 = PrimaiteSession(
|
|
TEST_CONFIG_ROOT / "e2e/ppo_seeded_training_config.yaml", dos_very_basic_config_path()
|
|
)
|
|
|
|
# load trained agent
|
|
ppo_session_2._training_config.agent_load_file = agent_path
|
|
ppo_session_2.setup()
|
|
|
|
# run evaluation
|
|
ppo_session_1.evaluate()
|
|
ppo_session_1.close()
|
|
ppo_session_2.evaluate()
|
|
ppo_session_2.close()
|
|
|
|
# compare output
|
|
# assert compare_transaction_file(
|
|
# ppo_session_1.evaluation_path / f"all_transactions_{ppo_session_1.timestamp_str}.csv",
|
|
# ppo_session_2.evaluation_path / f"all_transactions_{ppo_session_2.timestamp_str}.csv"
|
|
# ) is True
|