#1595: added loading sessions to run command + test + documentation for how to use loading sessions
This commit is contained in:
@@ -42,6 +42,36 @@ The sub-directory is formatted as such: ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy
|
|||||||
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
For example, when running a session at 17:30:00 on 31st January 2023, the session will output to:
|
||||||
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
``~/primaite/sessions/2023-01-31/2023-01-31_17-30-00/``.
|
||||||
|
|
||||||
|
Loading a session
|
||||||
|
-------
|
||||||
|
A previous session can be loaded by providing the **directory** of the previous session to either the ``primaite session`` command from the cli
|
||||||
|
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` with session_path.
|
||||||
|
|
||||||
|
.. tabs::
|
||||||
|
|
||||||
|
.. code-tab:: bash
|
||||||
|
:caption: Unix CLI
|
||||||
|
|
||||||
|
cd ~/primaite
|
||||||
|
source ./.venv/bin/activate
|
||||||
|
primaite session --load "path/to/session"
|
||||||
|
|
||||||
|
.. code-tab:: bash
|
||||||
|
:caption: Powershell CLI
|
||||||
|
|
||||||
|
cd ~\primaite
|
||||||
|
.\.venv\Scripts\activate
|
||||||
|
primaite session --load "path\to\session"
|
||||||
|
|
||||||
|
|
||||||
|
.. code-tab:: python
|
||||||
|
:caption: Python
|
||||||
|
|
||||||
|
from primaite.main import run
|
||||||
|
|
||||||
|
run(session_path=<previous session directory>)
|
||||||
|
|
||||||
|
When PrimAITE runs a loaded session, PrimAITE will output in the provided session directory
|
||||||
|
|
||||||
Outputs
|
Outputs
|
||||||
-------
|
-------
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
"""The main PrimAITE session runner module."""
|
"""The main PrimAITE session runner module."""
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from primaite import getLogger
|
from primaite import getLogger
|
||||||
from primaite.primaite_session import PrimaiteSession
|
from primaite.primaite_session import PrimaiteSession
|
||||||
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
training_config_path: Union[str, Path],
|
training_config_path: Optional[Union[str, Path]] = "",
|
||||||
lay_down_config_path: Union[str, Path],
|
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||||
|
session_path: Optional[Union[str, Path]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run the PrimAITE Session.
|
Run the PrimAITE Session.
|
||||||
|
|
||||||
:param training_config_path: The training config filepath.
|
:param training_config_path: YAML file containing configurable items defined in
|
||||||
:param lay_down_config_path: The lay down config filepath.
|
`primaite.config.training_config.TrainingConfig`
|
||||||
|
:type training_config_path: Union[path, str]
|
||||||
|
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||||
|
:type lay_down_config_path: Union[path, str]
|
||||||
|
:param session_path: directory path of the session to load
|
||||||
"""
|
"""
|
||||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
|
||||||
|
|
||||||
session.setup()
|
session.setup()
|
||||||
session.learn()
|
session.learn()
|
||||||
@@ -31,9 +36,14 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--tc")
|
parser.add_argument("--tc")
|
||||||
parser.add_argument("--ldc")
|
parser.add_argument("--ldc")
|
||||||
|
parser.add_argument("--load")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not args.tc:
|
if args.load:
|
||||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
run(session_path=args.load)
|
||||||
if not args.ldc:
|
else:
|
||||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
if not args.tc:
|
||||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||||
|
if not args.ldc:
|
||||||
|
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||||
|
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||||
|
|||||||
@@ -35,8 +35,12 @@ class PrimaiteSession:
|
|||||||
"""
|
"""
|
||||||
The PrimaiteSession constructor.
|
The PrimaiteSession constructor.
|
||||||
|
|
||||||
:param training_config_path: The training config path.
|
:param training_config_path: YAML file containing configurable items defined in
|
||||||
:param lay_down_config_path: The lay down config path.
|
`primaite.config.training_config.TrainingConfig`
|
||||||
|
:type training_config_path: Union[path, str]
|
||||||
|
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||||
|
:type lay_down_config_path: Union[path, str]
|
||||||
|
:param session_path: directory path of the session to load
|
||||||
"""
|
"""
|
||||||
self._agent_session: AgentSessionABC = None # noqa
|
self._agent_session: AgentSessionABC = None # noqa
|
||||||
self.session_path: Path = session_path # noqa
|
self.session_path: Path = session_path # noqa
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from uuid import uuid4
|
|||||||
from primaite import getLogger
|
from primaite import getLogger
|
||||||
from primaite.agents.sb3 import SB3Agent
|
from primaite.agents.sb3 import SB3Agent
|
||||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||||
|
from primaite.main import run
|
||||||
from primaite.primaite_session import PrimaiteSession
|
from primaite.primaite_session import PrimaiteSession
|
||||||
from primaite.utils.session_output_reader import av_rewards_dict
|
from primaite.utils.session_output_reader import av_rewards_dict
|
||||||
from tests import TEST_ASSETS_ROOT
|
from tests import TEST_ASSETS_ROOT
|
||||||
@@ -153,3 +154,35 @@ def test_load_primaite_session():
|
|||||||
|
|
||||||
# delete the test directory
|
# delete the test directory
|
||||||
shutil.rmtree(test_path)
|
shutil.rmtree(test_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_loading():
|
||||||
|
"""Test loading session via main.run."""
|
||||||
|
expected_learn_mean_reward_per_episode = {
|
||||||
|
10: 0,
|
||||||
|
11: -0.008037109374999995,
|
||||||
|
12: -0.007978515624999988,
|
||||||
|
13: -0.008191406249999991,
|
||||||
|
14: -0.00817578124999999,
|
||||||
|
15: -0.008085937499999998,
|
||||||
|
16: -0.007837890624999982,
|
||||||
|
17: -0.007798828124999992,
|
||||||
|
18: -0.007777343749999998,
|
||||||
|
19: -0.007958984374999988,
|
||||||
|
20: -0.0077499999999999835,
|
||||||
|
}
|
||||||
|
|
||||||
|
test_path = copy_session_asset(TEST_ASSETS_ROOT / "example_sb3_agent_session")
|
||||||
|
|
||||||
|
# create loaded session
|
||||||
|
run(session_path=test_path)
|
||||||
|
|
||||||
|
learn_mean_rewards = av_rewards_dict(
|
||||||
|
next(Path(test_path).rglob("**/learning/average_reward_per_episode_*.csv"), None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# run is seeded so should have the expected learn value
|
||||||
|
assert learn_mean_rewards == expected_learn_mean_reward_per_episode
|
||||||
|
|
||||||
|
# delete the test directory
|
||||||
|
shutil.rmtree(test_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user