#1595: added loading sessions to run command + test + documentation for how to use loading sessions

This commit is contained in:
Czar Echavez
2023-07-14 15:51:38 +01:00
parent 8e2f105d57
commit 7c2ff55da2
4 changed files with 90 additions and 13 deletions

View File

@@ -42,6 +42,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to: For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``. ``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
Loading a session
-------
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
.. tabs::
.. code-tab:: bash
:caption: Unix CLI
cd ~/primaite
source ./.venv/bin/activate
primaite session --load "path/to/session"
.. code-tab:: bash
:caption: Powershell CLI
cd ~\primaite
.\.venv\Scripts\activate
primaite session --load "path\to\session"
.. code-tab:: python
:caption: Python
from primaite.main import run
run(session_path=<previous session directory>)
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
Outputs Outputs
------- -------

View File

@@ -2,7 +2,7 @@
"""The main PrimAITE session runner module.""" """The main PrimAITE session runner module."""
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Union from typing import Optional, Union
from primaite import getLogger from primaite import getLogger
from primaite.primaite_session import PrimaiteSession from primaite.primaite_session import PrimaiteSession
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
def run( def run(
training_config_path: Union[str, Path], training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Union[str, Path], lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
): ):
""" """
Run the PrimAITE Session. Run the PrimAITE Session.
:param training_config_path: The training config filepath. :param training_config_path: YAML file containing configurable items defined in
:param lay_down_config_path: The lay down config filepath. `primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
""" """
session = PrimaiteSession(training_config_path, lay_down_config_path) session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
session.setup() session.setup()
session.learn() session.learn()
@@ -31,9 +36,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--tc") parser.add_argument("--tc")
parser.add_argument("--ldc") parser.add_argument("--ldc")
parser.add_argument("--load")
args = parser.parse_args() args = parser.parse_args()
if not args.tc: if args.load:
_LOGGER.error("Please provide a training config file using the --tc " "argument") run(session_path=args.load)
if not args.ldc: else:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument") if not args.tc:
run(training_config_path=args.tc, lay_down_config_path=args.ldc) _LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -35,8 +35,12 @@ class PrimaiteSession:
""" """
The PrimaiteSession constructor. The PrimaiteSession constructor.
:param training_config_path: The training config path. :param training_config_path: YAML file containing configurable items defined in
:param lay_down_config_path: The lay down config path. `primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
""" """
self._agent_session: AgentSessionABC = None # noqa self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = session_path # noqa self.session_path: Path = session_path # noqa

View File

@@ -8,6 +8,7 @@ from uuid import uuid4
from primaite import getLogger from primaite import getLogger
from primaite.agents.sb3 import SB3Agent from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.main import run
from primaite.primaite_session import PrimaiteSession from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict from primaite.utils.session_output_reader import av_rewards_dict
from tests import TEST_ASSETS_ROOT from tests import TEST_ASSETS_ROOT
@@ -153,3 +154,35 @@ def test_load_primaite_session():
# delete the test directory # delete the test directory
shutil.rmtree(test_path) shutil.rmtree(test_path)
def test_run_loading():
"""Test loading session via main.run."""
expected_learn_mean_reward_per_episode = {
10: 0,
11: -0.008037109374999995,
12: -0.007978515624999988,
13: -0.008191406249999991,
14: -0.00817578124999999,
15: -0.008085937499999998,
16: -0.007837890624999982,
17: -0.007798828124999992,
18: -0.007777343749999998,
19: -0.007958984374999988,
20: -0.0077499999999999835,
}
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
# create loaded session
run(session_path=test_path)
learn_mean_rewards = av_rewards_dict(
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
)
# run is seeded so should have the expected learn value
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
# delete the test directory
shutil.rmtree(test_path)