Files
PrimAITE/tests/test_session_loading.py

188 lines
6.2 KiB
Python

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import os.path
import shutil
import tempfile
from pathlib import Path
from typing import Union
from uuid import uuid4
from typer.testing import CliRunner
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.cli import app
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.main import run
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
_LOGGER = getLogger(__name__)
runner = CliRunner()
sb3_expected_avg_reward_per_episode = {
10: 0.0,
11: -0.0011074218750000008,
12: -0.0010000000000000007,
13: -0.0016601562500000013,
14: -0.001400390625000001,
15: -0.0009863281250000007,
16: -0.0011855468750000008,
17: -0.0009511718750000007,
18: -0.0008789062500000007,
19: -0.0012226562500000009,
20: -0.0010292968750000007,
}
sb3_expected_eval_rewards = -0.0018515625000000014
def copy_session_asset(asset_path: Union[str, Path]) -> str:
"""Copies the asset into a temporary test folder."""
if asset_path is None:
raise Exception("No path provided")
if isinstance(asset_path, Path):
asset_path = str(os.path.normpath(asset_path))
copy_path = str(Path(tempfile.gettempdir()) / "primaite" / str(uuid4()))
# copy the asset into a temp path
try:
shutil.copytree(asset_path, copy_path)
except Exception as e:
msg = f"Unable to copy directory: {asset_path}"
_LOGGER.error(msg, e)
print(msg, e)
_LOGGER.debug(f"Copied test asset to: {copy_path}")
# return the copied assets path
return copy_path
def test_load_sb3_session():
"""Test that loading an SB3 agent works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
loaded_agent = SB3Agent(session_path=test_path)
# loaded agent should have the same UUID as the previous agent
assert loaded_agent.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert loaded_agent._training_config.agent_framework == AgentFramework.SB3.name
assert loaded_agent._training_config.agent_identifier == AgentIdentifier.PPO.name
assert loaded_agent._training_config.deterministic
assert loaded_agent._training_config.seed == 12345
assert str(loaded_agent.session_path) == str(test_path)
# run another learn session
loaded_agent.learn()
learn_mean_rewards = av_rewards_dict(
loaded_agent.learning_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# run an evaluation
loaded_agent.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
loaded_agent.evaluation_path / f"average_reward_per_episode_{loaded_agent.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
# delete the test directory
shutil.rmtree(test_path)
def test_load_primaite_session():
"""Test that loading a Primaite session works."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
session = PrimaiteSession(session_path=test_path)
# run setup on session
session.setup()
# make sure that the session was loaded correctly
assert session._agent_session.uuid == "301874d3-2e14-43c2-ba7f-e2b03ad05dde"
assert session._agent_session._training_config.agent_framework == AgentFramework.SB3.name
assert session._agent_session._training_config.agent_identifier == AgentIdentifier.PPO.name
assert session._agent_session._training_config.deterministic
assert session._agent_session._training_config.seed == 12345
assert str(session._agent_session.session_path) == str(test_path)
# run another learn session
session.learn()
learn_mean_rewards = av_rewards_dict(
session.learning_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# run an evaluation
session.evaluate()
# load the evaluation average reward csv file
eval_mean_reward = av_rewards_dict(
session.evaluation_path / f"average_reward_per_episode_{session.timestamp_str}.csv"
)
# the agent config ran the evaluation in deterministic mode, so should have the same reward value
assert len(set(eval_mean_reward.values())) == 1
# the evaluation should be the same as a previous run
assert next(iter(set(eval_mean_reward.values()))) == sb3_expected_eval_rewards
# delete the test directory
shutil.rmtree(test_path)
def test_run_loading():
"""Test loading session via main.run."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
run(session_path=test_path)
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)
def test_cli():
"""Test loading session via CLI."""
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
result = runner.invoke(app, ["session", "--load", test_path])
# cli should work
assert result.exit_code == 0
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == sb3_expected_avg_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)