#1595: added loading sessions to run command + test + documentation for how to use loading sessions
This commit is contained in:
@@ -42,6 +42,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
|
||||
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
||||
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||
|
||||
Loading a session
|
||||
-------
|
||||
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
|
||||
|
||||
.. tabs::
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Unix CLI
|
||||
|
||||
cd ~/primaite
|
||||
source ./.venv/bin/activate
|
||||
primaite session --load "path/to/session"
|
||||
|
||||
.. code-tab:: bash
|
||||
:caption: Powershell CLI
|
||||
|
||||
cd ~\primaite
|
||||
.\.venv\Scripts\activate
|
||||
primaite session --load "path\to\session"
|
||||
|
||||
|
||||
.. code-tab:: python
|
||||
:caption: Python
|
||||
|
||||
from primaite.main import run
|
||||
|
||||
run(session_path=<previous session directory>)
|
||||
|
||||
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
@@ -31,9 +36,14 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
@@ -35,8 +35,12 @@ class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import uuid4
|
||||
from primaite import getLogger
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.main import run
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.utils.session_output_reader import av_rewards_dict
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
@@ -153,3 +154,35 @@ def test_load_primaite_session():
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
|
||||
def test_run_loading():
|
||||
"""Test loading session via main.run."""
|
||||
expected_learn_mean_reward_per_episode = {
|
||||
10: 0,
|
||||
11: -0.008037109374999995,
|
||||
12: -0.007978515624999988,
|
||||
13: -0.008191406249999991,
|
||||
14: -0.00817578124999999,
|
||||
15: -0.008085937499999998,
|
||||
16: -0.007837890624999982,
|
||||
17: -0.007798828124999992,
|
||||
18: -0.007777343749999998,
|
||||
19: -0.007958984374999988,
|
||||
20: -0.0077499999999999835,
|
||||
}
|
||||
|
||||
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||
|
||||
# create loaded session
|
||||
run(session_path=test_path)
|
||||
|
||||
learn_mean_rewards = av_rewards_dict(
|
||||
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||
)
|
||||
|
||||
# run is seeded so should have the expected learn value
|
||||
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||
|
||||
# delete the test directory
|
||||
shutil.rmtree(test_path)
|
||||
|
||||
Reference in New Issue
Block a user