From d4eb499729f21f524a33f76fcefdcb4524e8442b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 27 Oct 2023 12:20:23 +0100 Subject: [PATCH] Start fixing up cli and setup --- sandbox.py | 1 - src/primaite/cli.py | 58 ++------------------------------- src/primaite/config/__init__.py | 2 ++ src/primaite/config/load.py | 22 +++++++++++++ 4 files changed, 26 insertions(+), 57 deletions(-) create mode 100644 src/primaite/config/__init__.py create mode 100644 src/primaite/config/load.py diff --git a/sandbox.py b/sandbox.py index ab5e701f..c5c8ae38 100644 --- a/sandbox.py +++ b/sandbox.py @@ -6,7 +6,6 @@ from primaite.game.session import PrimaiteSession _PRIMAITE_CONFIG["log_level"] = logging.DEBUG print(PRIMAITE_PATHS.app_log_dir_path) -import itertools import yaml diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 9bdc414d..ea0247ec 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -10,7 +10,6 @@ import yaml from typing_extensions import Annotated from primaite import PRIMAITE_PATHS -from primaite.data_viz import PlotlyTemplate app = typer.Typer() @@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) -> print(f"PrimAITE Log Level: {level}") -@app.command() -def notebooks() -> None: - """Start Jupyter Lab in the users PrimAITE notebooks directory.""" - from primaite.notebooks import start_jupyter_session - - start_jupyter_session() - - @app.command() def version() -> None: """Get the installed PrimAITE version number.""" @@ -97,14 +88,6 @@ def version() -> None: print(primaite.__version__) -@app.command() -def clean_up() -> None: - """Cleans up left over files from previous version installations.""" - from primaite.setup import old_installation_clean_up - - old_installation_clean_up.run() - - @app.command() def setup(overwrite_existing: bool = True) -> None: """ @@ -113,7 +96,7 @@ def setup(overwrite_existing: bool = True) -> None: WARNING: All user-data will be lost. """ from primaite import getLogger - from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs + from primaite.setup import reset_demo_notebooks, reset_example_configs _LOGGER = getLogger(__name__) @@ -130,19 +113,12 @@ def setup(overwrite_existing: bool = True) -> None: _LOGGER.info("Rebuilding the example notebooks...") reset_example_configs.run(overwrite_existing=True) - _LOGGER.info("Performing a clean-up of previous PrimAITE installations...") - old_installation_clean_up.run() - _LOGGER.info("PrimAITE setup complete!") @app.command() def session( - tc: Optional[str] = None, - ldc: Optional[str] = None, - load: Optional[str] = None, - legacy_tc: bool = False, - legacy_ldc: bool = False, + config: Optional[str] = None, ) -> None: """ Run a PrimAITE session. @@ -165,13 +141,8 @@ def session( legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0. """ from primaite.config.lay_down_config import dos_very_basic_config_path - from primaite.config.training_config import main_training_config_path from primaite.main import run - if load is not None: - # run a loaded session - run(session_path=load) - else: # start a new session using tc and ldc if not tc: @@ -186,28 +157,3 @@ def session( legacy_training_config=legacy_tc, legacy_lay_down_config=legacy_ldc, ) - - -@app.command() -def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None: - """ - View or set the plotly template for Session plots. - - To View, simply call: primaite plotly-template - - To set, call: primaite plotly-template - - For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK - """ - if PRIMAITE_PATHS.app_config_file_path.exists(): - with open(PRIMAITE_PATHS.app_config_file_path, "r") as file: - primaite_config = yaml.safe_load(file) - - if template: - primaite_config["session"]["outputs"]["plots"]["template"] = template.value - with open(PRIMAITE_PATHS.app_config_file_path, "w") as file: - yaml.dump(primaite_config, file) - print(f"PrimAITE plotly template: {template.value}") - else: - template = primaite_config["session"]["outputs"]["plots"]["template"] - print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/config/__init__.py b/src/primaite/config/__init__.py new file mode 100644 index 00000000..92f5a7d2 --- /dev/null +++ b/src/primaite/config/__init__.py @@ -0,0 +1,2 @@ +# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +"""Configuration parameters for running experiments.""" diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py new file mode 100644 index 00000000..77b76299 --- /dev/null +++ b/src/primaite/config/load.py @@ -0,0 +1,22 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def load(file_path: Union[str, Path]) -> Dict: + """ + Read a YAML file and return the contents as a dictionary. + + :param file_path: Path to the YAML file. + :type file_path: Union[str, Path] + :return: Config dictionary + :rtype: Dict + """ + + if not isinstance(file_path, Path): + file_path = Path(file_path)