#915 - Annotated logs func in cli.py to take -n.
- Fixed entry point on main.py - Commented out the print reward line in step func of primaite_env.py. - Added jupyterlab==3.6.1 to pyproject.toml
This commit is contained in:
@@ -26,6 +26,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
"gym==0.21.0",
|
||||
"jupyterlab==3.6.1",
|
||||
"matplotlib==3.7.1",
|
||||
"networkx==3.1",
|
||||
"numpy==1.23.5",
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
import typer
|
||||
from platformdirs import PlatformDirs
|
||||
from typing_extensions import Annotated
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -33,26 +33,17 @@ def reset_notebooks(overwrite: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def logs(last_n: int = 10):
|
||||
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
"""
|
||||
Print the PrimAITE log file.
|
||||
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
import re
|
||||
from primaite import LOG_PATH
|
||||
|
||||
from platformdirs import PlatformDirs
|
||||
|
||||
yt_platform_dirs = PlatformDirs(appname="primaite")
|
||||
|
||||
if sys.platform == "win32":
|
||||
log_dir = yt_platform_dirs.user_data_path / "logs"
|
||||
else:
|
||||
log_dir = yt_platform_dirs.user_log_path
|
||||
log_path = os.path.join(log_dir, "primaite.log")
|
||||
|
||||
if os.path.isfile(log_path):
|
||||
with open(log_path) as file:
|
||||
if os.path.isfile(LOG_PATH):
|
||||
with open(LOG_PATH) as file:
|
||||
lines = file.readlines()
|
||||
for line in lines[-last_n:]:
|
||||
print(re.sub(r"\n*", "", line))
|
||||
@@ -89,6 +80,7 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
# Does this way to avoid using PrimAITE package before config is loaded
|
||||
app_dirs = PlatformDirs(appname="primaite")
|
||||
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
|
||||
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
|
||||
|
||||
@@ -377,7 +377,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
|
||||
@@ -6,6 +6,7 @@ TODO: This will eventually be refactored out into a proper Session class.
|
||||
TODO: The passing about of session_dir and timestamp_str is temporary and
|
||||
will be cleaned up once we move to a proper Session class.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -19,10 +20,10 @@ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.config.training_config import TrainingConfig, main_training_config_path
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -334,19 +335,19 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
|
||||
_LOGGER.debug("Finished")
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--tc")
|
||||
# parser.add_argument("--ldc")
|
||||
# args = parser.parse_args()
|
||||
# if not args.tc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a training config file using the --tc " "argument"
|
||||
# )
|
||||
# if not args.ldc:
|
||||
# _LOGGER.error(
|
||||
# "Please provide a lay down config file using the --ldc " "argument"
|
||||
# )
|
||||
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error(
|
||||
"Please provide a training config file using the --tc " "argument"
|
||||
)
|
||||
if not args.ldc:
|
||||
_LOGGER.error(
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
run(main_training_config_path(), data_manipulation_config_path())
|
||||
|
||||
@@ -17,20 +17,16 @@ def start_jupyter_session():
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
# Jupyter is installed
|
||||
jupyter_cmd = "python3 -m jupyter lab"
|
||||
if sys.platform == "win32":
|
||||
jupyter_cmd = "jupyter lab"
|
||||
|
||||
working_dir = os.getcwd()
|
||||
os.chdir(NOTEBOOKS_DIR)
|
||||
subprocess.Popen("jupyter lab")
|
||||
subprocess.Popen(jupyter_cmd)
|
||||
os.chdir(working_dir)
|
||||
else:
|
||||
# Jupyter is not installed
|
||||
_LOGGER.error("Cannot start jupyter lab as it is not installed")
|
||||
else:
|
||||
msg = (
|
||||
"Feature currently only supported on Windows OS. For "
|
||||
"Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter "
|
||||
"lab' from your Python environment."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
Reference in New Issue
Block a user