#915 - Annotated logs func in cli.py to take -n.

- Fixed entry point on main.py
- Commented out the print reward line in step func of primaite_env.py.
- Added jupyterlab==3.6.1 to pyproject.toml
This commit is contained in:
Chris McCarthy
2023-06-09 16:44:49 +01:00
parent 1a7d158d77
commit 149a534851
5 changed files with 39 additions and 49 deletions

View File

@@ -26,6 +26,7 @@ classifiers = [
dependencies = [ dependencies = [
"gym==0.21.0", "gym==0.21.0",
"jupyterlab==3.6.1",
"matplotlib==3.7.1", "matplotlib==3.7.1",
"networkx==3.1", "networkx==3.1",
"numpy==1.23.5", "numpy==1.23.5",

View File

@@ -2,12 +2,12 @@
"""Provides a CLI using Typer as an entry point.""" """Provides a CLI using Typer as an entry point."""
import os import os
import shutil import shutil
import sys
from pathlib import Path from pathlib import Path
import pkg_resources import pkg_resources
import typer import typer
from platformdirs import PlatformDirs from platformdirs import PlatformDirs
from typing_extensions import Annotated
app = typer.Typer() app = typer.Typer()
@@ -33,26 +33,17 @@ def reset_notebooks(overwrite: bool = True):
@app.command() @app.command()
def logs(last_n: int = 10): def logs(last_n: Annotated[int, typer.Option("-n")]):
""" """
Print the PrimAITE log file. Print the PrimAITE log file.
:param last_n: The number of lines to print. Default value is 10. :param last_n: The number of lines to print. Default value is 10.
""" """
import re import re
from primaite import LOG_PATH
from platformdirs import PlatformDirs if os.path.isfile(LOG_PATH):
with open(LOG_PATH) as file:
yt_platform_dirs = PlatformDirs(appname="primaite")
if sys.platform == "win32":
log_dir = yt_platform_dirs.user_data_path / "logs"
else:
log_dir = yt_platform_dirs.user_log_path
log_path = os.path.join(log_dir, "primaite.log")
if os.path.isfile(log_path):
with open(log_path) as file:
lines = file.readlines() lines = file.readlines()
for line in lines[-last_n:]: for line in lines[-last_n:]:
print(re.sub(r"\n*", "", line)) print(re.sub(r"\n*", "", line))
@@ -89,6 +80,7 @@ def setup(overwrite_existing: bool = True):
WARNING: All user-data will be lost. WARNING: All user-data will be lost.
""" """
# Does this way to avoid using PrimAITE package before config is loaded
app_dirs = PlatformDirs(appname="primaite") app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml" user_config_path = app_dirs.user_config_path / "primaite_config.yaml"

View File

@@ -377,7 +377,7 @@ class Primaite(Env):
self.step_count, self.step_count,
self.training_config, self.training_config,
) )
print(f" Step {self.step_count} Reward: {str(reward)}") #print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward self.total_reward += reward
if self.step_count == self.episode_steps: if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count self.average_reward = self.total_reward / self.step_count

View File

@@ -6,6 +6,7 @@ TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_dir and timestamp_str is temporary and TODO: The passing about of session_dir and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class. will be cleaned up once we move to a proper Session class.
""" """
import argparse
import json import json
import time import time
from datetime import datetime from datetime import datetime
@@ -19,10 +20,10 @@ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger from primaite import SESSIONS_DIR, getLogger
from primaite.config.lay_down_config import data_manipulation_config_path from primaite.config.training_config import TrainingConfig
from primaite.config.training_config import TrainingConfig, main_training_config_path
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import write_transaction_to_file from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@@ -334,19 +335,19 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
_LOGGER.debug("Finished") _LOGGER.debug("Finished")
# if __name__ == "__main__": if __name__ == "__main__":
# parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# parser.add_argument("--tc") parser.add_argument("--tc")
# parser.add_argument("--ldc") parser.add_argument("--ldc")
# args = parser.parse_args() args = parser.parse_args()
# if not args.tc: if not args.tc:
# _LOGGER.error( _LOGGER.error(
# "Please provide a training config file using the --tc " "argument" "Please provide a training config file using the --tc " "argument"
# ) )
# if not args.ldc: if not args.ldc:
# _LOGGER.error( _LOGGER.error(
# "Please provide a lay down config file using the --ldc " "argument" "Please provide a lay down config file using the --ldc " "argument"
# ) )
# run(training_config_path=args.tc, lay_down_config_path=args.ldc) run(training_config_path=args.tc, lay_down_config_path=args.ldc)
run(main_training_config_path(), data_manipulation_config_path())

View File

@@ -17,20 +17,16 @@ def start_jupyter_session():
.. todo:: Figure out how to get this working for Linux and MacOS too. .. todo:: Figure out how to get this working for Linux and MacOS too.
""" """
if sys.platform == "win32":
if importlib.util.find_spec("jupyter") is not None: if importlib.util.find_spec("jupyter") is not None:
# Jupyter is installed jupyter_cmd = "python3 -m jupyter lab"
working_dir = os.getcwd() if sys.platform == "win32":
os.chdir(NOTEBOOKS_DIR) jupyter_cmd = "jupyter lab"
subprocess.Popen("jupyter lab")
os.chdir(working_dir) working_dir = os.getcwd()
else: os.chdir(NOTEBOOKS_DIR)
# Jupyter is not installed subprocess.Popen(jupyter_cmd)
_LOGGER.error("Cannot start jupyter lab as it is not installed") os.chdir(working_dir)
else: else:
msg = ( # Jupyter is not installed
"Feature currently only supported on Windows OS. For " _LOGGER.error("Cannot start jupyter lab as it is not installed")
"Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter "
"lab' from your Python environment."
)
_LOGGER.warning(msg)