#1595: added loading sessions to run command + test + documentation for how to use loading sessions

This commit is contained in:
Czar Echavez
2023-07-14 15:51:38 +01:00
parent 8e2f105d57
commit 7c2ff55da2
4 changed files with 90 additions and 13 deletions

View File

@@ -42,6 +42,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
Loading a session
-------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Outputs
-------

View File

@@ -2,7 +2,7 @@
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Union
from typing import Optional, Union
from primaite import getLogger
from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
"""
Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
session = PrimaiteSession(training_config_path, lay_down_config_path)
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup()
session.learn()
@@ -31,7 +36,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
parser.add_argument("--load")
args = parser.parse_args()
if args.load:
run(session_path=args.load)
else:
if not args.tc:
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:

View File

@@ -35,8 +35,12 @@ class PrimaiteSession:
"""
The PrimaiteSession constructor.
:param training_config_path: The training config path.
:param lay_down_config_path: The lay down config path.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
"""
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa

View File

@@ -8,6 +8,7 @@ from uuid import uuid4
from primaite import getLogger
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.main import run
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT
@@ -153,3 +154,35 @@ def test_load_primaite_session():
# delete the test directory
shutil.rmtree(test_path)
def test_run_loading():
"""Test loading session via main.run."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
run(session_path=test_path)
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)