Start fixing up cli and setup
This commit is contained in:
@@ -6,7 +6,6 @@ from primaite.game.session import PrimaiteSession
|
||||
|
||||
_PRIMAITE_CONFIG["log_level"] = logging.DEBUG
|
||||
print(PRIMAITE_PATHS.app_log_dir_path)
|
||||
import itertools
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import yaml
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.data_viz import PlotlyTemplate
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
start_jupyter_session()
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
@@ -97,14 +88,6 @@ def version() -> None:
|
||||
print(primaite.__version__)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
old_installation_clean_up.run()
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
@@ -113,7 +96,7 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from primaite import getLogger
|
||||
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
|
||||
from primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -130,19 +113,12 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
|
||||
old_installation_clean_up.run()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
tc: Optional[str] = None,
|
||||
ldc: Optional[str] = None,
|
||||
load: Optional[str] = None,
|
||||
legacy_tc: bool = False,
|
||||
legacy_ldc: bool = False,
|
||||
config: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
@@ -165,13 +141,8 @@ def session(
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if load is not None:
|
||||
# run a loaded session
|
||||
run(session_path=load)
|
||||
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
@@ -186,28 +157,3 @@ def session(
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
To set, call: primaite plotly-template <desired template>
|
||||
|
||||
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if template:
|
||||
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE plotly template: {template.value}")
|
||||
else:
|
||||
template = primaite_config["session"]["outputs"]["plots"]["template"]
|
||||
print(f"PrimAITE plotly template: {template}")
|
||||
|
||||
2
src/primaite/config/__init__.py
Normal file
2
src/primaite/config/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Configuration parameters for running experiments."""
|
||||
22
src/primaite/config/load.py
Normal file
22
src/primaite/config/load.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path]) -> Dict:
|
||||
"""
|
||||
Read a YAML file and return the contents as a dictionary.
|
||||
|
||||
:param file_path: Path to the YAML file.
|
||||
:type file_path: Union[str, Path]
|
||||
:return: Config dictionary
|
||||
:rtype: Dict
|
||||
"""
|
||||
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
Reference in New Issue
Block a user