diff --git a/pyproject.toml b/pyproject.toml index 66840f1b..7ddf7710 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ dependencies = [ "gym==0.21.0", + "jupyterlab==3.6.1", "matplotlib==3.7.1", "networkx==3.1", "numpy==1.23.5", diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 4d973179..1abf625c 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -2,12 +2,12 @@ """Provides a CLI using Typer as an entry point.""" import os import shutil -import sys from pathlib import Path import pkg_resources import typer from platformdirs import PlatformDirs +from typing_extensions import Annotated app = typer.Typer() @@ -33,26 +33,17 @@ def reset_notebooks(overwrite: bool = True): @app.command() -def logs(last_n: int = 10): +def logs(last_n: Annotated[int, typer.Option("-n")]): """ Print the PrimAITE log file. :param last_n: The number of lines to print. Default value is 10. """ import re + from primaite import LOG_PATH - from platformdirs import PlatformDirs - - yt_platform_dirs = PlatformDirs(appname="primaite") - - if sys.platform == "win32": - log_dir = yt_platform_dirs.user_data_path / "logs" - else: - log_dir = yt_platform_dirs.user_log_path - log_path = os.path.join(log_dir, "primaite.log") - - if os.path.isfile(log_path): - with open(log_path) as file: + if os.path.isfile(LOG_PATH): + with open(LOG_PATH) as file: lines = file.readlines() for line in lines[-last_n:]: print(re.sub(r"\n*", "", line)) @@ -89,6 +80,7 @@ def setup(overwrite_existing: bool = True): WARNING: All user-data will be lost. """ + # Does this way to avoid using PrimAITE package before config is loaded app_dirs = PlatformDirs(appname="primaite") app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) user_config_path = app_dirs.user_config_path / "primaite_config.yaml" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 71537817..d3c45c79 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -377,7 +377,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - print(f" Step {self.step_count} Reward: {str(reward)}") + #print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count diff --git a/src/primaite/main.py b/src/primaite/main.py index 0ec7d8ef..ac32a018 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -6,6 +6,7 @@ TODO: This will eventually be refactored out into a proper Session class. TODO: The passing about of session_dir and timestamp_str is temporary and will be cleaned up once we move to a proper Session class. """ +import argparse import json import time from datetime import datetime @@ -19,10 +20,10 @@ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import SESSIONS_DIR, getLogger -from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.config.training_config import TrainingConfig, main_training_config_path +from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import write_transaction_to_file +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file _LOGGER = getLogger(__name__) @@ -334,19 +335,19 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, _LOGGER.debug("Finished") -# if __name__ == "__main__": -# parser = argparse.ArgumentParser() -# parser.add_argument("--tc") -# parser.add_argument("--ldc") -# args = parser.parse_args() -# if not args.tc: -# _LOGGER.error( -# "Please provide a training config file using the --tc " "argument" -# ) -# if not args.ldc: -# _LOGGER.error( -# "Please provide a lay down config file using the --ldc " "argument" -# ) -# run(training_config_path=args.tc, lay_down_config_path=args.ldc) +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tc") + parser.add_argument("--ldc") + args = parser.parse_args() + if not args.tc: + _LOGGER.error( + "Please provide a training config file using the --tc " "argument" + ) + if not args.ldc: + _LOGGER.error( + "Please provide a lay down config file using the --ldc " "argument" + ) + run(training_config_path=args.tc, lay_down_config_path=args.ldc) + -run(main_training_config_path(), data_manipulation_config_path()) diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index d8c93b9a..6d822961 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -17,20 +17,16 @@ def start_jupyter_session(): .. todo:: Figure out how to get this working for Linux and MacOS too. """ - if sys.platform == "win32": - if importlib.util.find_spec("jupyter") is not None: - # Jupyter is installed - working_dir = os.getcwd() - os.chdir(NOTEBOOKS_DIR) - subprocess.Popen("jupyter lab") - os.chdir(working_dir) - else: - # Jupyter is not installed - _LOGGER.error("Cannot start jupyter lab as it is not installed") + + if importlib.util.find_spec("jupyter") is not None: + jupyter_cmd = "python3 -m jupyter lab" + if sys.platform == "win32": + jupyter_cmd = "jupyter lab" + + working_dir = os.getcwd() + os.chdir(NOTEBOOKS_DIR) + subprocess.Popen(jupyter_cmd) + os.chdir(working_dir) else: - msg = ( - "Feature currently only supported on Windows OS. For " - "Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter " - "lab' from your Python environment." - ) - _LOGGER.warning(msg) + # Jupyter is not installed + _LOGGER.error("Cannot start jupyter lab as it is not installed")