96 lines
4.4 KiB
Python
96 lines
4.4 KiB
Python
import pydantic
|
|
import pytest
|
|
|
|
from tests import TEST_ASSETS_ROOT
|
|
from tests.conftest import TempPrimaiteSession
|
|
|
|
# Session config used by the creation, start-session and sim-reset tests.
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
# Session config used by the training-only test.
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
# Session config used by the eval-only test.
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
# Deliberately invalid config; loading it is expected to raise a pydantic.ValidationError.
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
# Session config used by the (currently skipped) multi-agent test.
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
|
|
|
|
|
class TestPrimaiteSession:
    """Integration tests that build and run PrimAITE sessions from YAML configs.

    Most tests receive their session through the ``temp_primaite_session``
    fixture, which is parametrized indirectly with a one-element list holding
    the config path and is used as a context manager inside each test.
    """

    @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
    def test_creating_session(self, temp_primaite_session):
        """Check that creating a session from config works."""
        with temp_primaite_session as session:
            # isinstance() is False for None, so this also covers "session is not None".
            assert isinstance(session, TempPrimaiteSession)
            assert session.game.simulation
            assert len(session.game.agents) == 3
            assert len(session.game.rl_agents) == 1

            assert session.policy
            assert session.env

            assert session.game.simulation.network
            assert len(session.game.simulation.network.nodes) == 10

    @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
    def test_start_session(self, temp_primaite_session):
        """Make sure you can go all the way through the session without errors."""
        with temp_primaite_session as session:
            session: TempPrimaiteSession
            session.start_session()

            session_path = session.io_manager.session_path
            assert session_path.exists()
            # Printed so the directory listing shows up in pytest's captured
            # output when one of the checkpoint assertions below fails.
            print(list(session_path.glob("*")))
            checkpoint_dir = session_path / "checkpoints" / "sb3_final"
            assert checkpoint_dir.exists()
            checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
            checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
            checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
            # Exactly the 640- and 1280-step checkpoints are expected; a
            # 1920-step one must not exist (presumably training ends before
            # then -- confirm against the session config).
            assert checkpoint_1.exists()
            assert checkpoint_2.exists()
            assert not checkpoint_3.exists()

    @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
    def test_training_only_session(self, temp_primaite_session):
        """Check that you can run a training-only session."""
        with temp_primaite_session as session:
            session: TempPrimaiteSession
            session.start_session()
            # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?

    @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
    def test_eval_only_session(self, temp_primaite_session):
        """Check that you can load a model and run an eval-only session."""
        with temp_primaite_session as session:
            session: TempPrimaiteSession
            session.start_session()
            # TODO: include checks that the model was loaded and that the eval-only session ran

    @pytest.mark.skip(reason="Slow, reenable later")
    @pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
    def test_multi_agent_session(self, temp_primaite_session):
        """Check that we can run a training session with a multi agent system."""
        with temp_primaite_session as session:
            session: TempPrimaiteSession
            session.start_session()

    def test_error_thrown_on_bad_configuration(self):
        """Check that loading a misconfigured session raises a pydantic.ValidationError."""
        with pytest.raises(pydantic.ValidationError):
            # Only the raised exception matters, so the result is not bound to a name.
            TempPrimaiteSession.from_config(MISCONFIGURED_PATH)

    @pytest.mark.skip(
        reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
        "reset doesn't implement software restore."
    )
    @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
    def test_session_sim_reset(self, temp_primaite_session):
        """Check that resetting the game restores software that was uninstalled from a node."""
        with temp_primaite_session as session:
            session: TempPrimaiteSession
            client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
            client_1.software_manager.uninstall("DataManipulationBot")

            assert "DataManipulationBot" not in client_1.software_manager.software

            session.game.reset()
            # Look the node up again after the reset rather than reusing the
            # pre-reset reference.
            client_1 = session.game.simulation.network.get_node_by_hostname("client_1")

            assert "DataManipulationBot" in client_1.software_manager.software
|