#917 - Integrated the PrimaiteSession into all tests.
- Ran a full pre-commit hook and thus encountered tons of fixes required
This commit is contained in:
61
tests/test_primaite_session.py
Normal file
61
tests/test_primaite_session.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[[main_training_config_path(), dos_very_basic_config_path()]],
|
||||
indirect=True,
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
"""Tests the PrimaiteSession class and its outputs."""
|
||||
with temp_primaite_session as session:
|
||||
session_path = session.session_path
|
||||
assert session_path.exists()
|
||||
session.learn()
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.evaluate()
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
|
||||
# If you need to inspect any session outputs, it must be done inside
|
||||
# the context manager
|
||||
|
||||
# Check that the metadata json file exists
|
||||
assert (session_path / "session_metadata.json").exists()
|
||||
|
||||
# Check that the network png file exists
|
||||
assert (session_path / f"network_{session.timestamp_str}.png").exists()
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.evaluation_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
for file in file_names:
|
||||
path = os.path.join(dir_path, file)
|
||||
file_str = path.split(str(session_path))[-1]
|
||||
_LOGGER.debug(f" {file_str}")
|
||||
|
||||
# Now that we've exited the context manager, the session.session_path
|
||||
# directory and its contents are deleted
|
||||
assert not session_path.exists()
|
||||
Reference in New Issue
Block a user