diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index a393093c..8079212e 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -42,6 +42,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions//) + +When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory Outputs ------- diff --git a/src/primaite/main.py b/src/primaite/main.py index f2d1b9c2..9fcc4df6 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -2,7 +2,7 @@ """The main PrimAITE session runner module.""" import argparse from pathlib import Path -from typing import Union +from typing import Optional, Union from primaite import getLogger from primaite.primaite_session import PrimaiteSession @@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__) def run( - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], + training_config_path: Optional[Union[str, Path]] = "", + lay_down_config_path: Optional[Union[str, Path]] = "", + session_path: Optional[Union[str, Path]] = None, ): """ Run the PrimAITE Session. - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ - session = PrimaiteSession(training_config_path, lay_down_config_path) + session = PrimaiteSession(training_config_path, lay_down_config_path, session_path) session.setup() session.learn() @@ -31,9 +36,14 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tc") parser.add_argument("--ldc") + parser.add_argument("--load") + args = parser.parse_args() - if not args.tc: - _LOGGER.error("Please provide a training config file using the --tc " "argument") - if not args.ldc: - _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") - run(training_config_path=args.tc, lay_down_config_path=args.ldc) + if args.load: + run(session_path=args.load) + else: + if not args.tc: + _LOGGER.error("Please provide a training config file using the --tc " "argument") + if not args.ldc: + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") + run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 4dab5cb6..76134238 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -35,8 +35,12 @@ class PrimaiteSession: """ The PrimaiteSession constructor. - :param training_config_path: The training config path. - :param lay_down_config_path: The lay down config path. + :param training_config_path: YAML file containing configurable items defined in + `primaite.config.training_config.TrainingConfig` + :type training_config_path: Union[path, str] + :param lay_down_config_path: YAML file containing configurable items for generating network laydown. + :type lay_down_config_path: Union[path, str] + :param session_path: directory path of the session to load """ self._agent_session: AgentSessionABC = None # noqa self.session_path: Path = session_path # noqa diff --git a/tests/test_session_loading.py b/tests/test_session_loading.py index d79b0dde..54cac351 100644 --- a/tests/test_session_loading.py +++ b/tests/test_session_loading.py @@ -8,6 +8,7 @@ from uuid import uuid4 from primaite import getLogger from primaite.agents.sb3 import SB3Agent from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.main import run from primaite.primaite_session import PrimaiteSession from primaite.utils.session_output_reader import av_rewards_dict from tests import TEST_ASSETS_ROOT @@ -153,3 +154,35 @@ def test_load_primaite_session(): # delete the test directory shutil.rmtree(test_path) + + +def test_run_loading(): + """Test loading session via main.run.""" + expected_learn_mean_reward_per_episode = { + 10: 0, + 11: -0.008037109374999995, + 12: -0.007978515624999988, + 13: -0.008191406249999991, + 14: -0.00817578124999999, + 15: -0.008085937499999998, + 16: -0.007837890624999982, + 17: -0.007798828124999992, + 18: -0.007777343749999998, + 19: -0.007958984374999988, + 20: -0.0077499999999999835, + } + + test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session") + + # create loaded session + run(session_path=test_path) + + learn_mean_rewards = av_rewards_dict( + next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None) + ) + + # run is seeded so should have the expected learn value + assert learn_mean_rewards == expected_learn_mean_reward_per_episode + + # delete the test directory + shutil.rmtree(test_path)