#915 - Annotated logs func in cli.py to take -n.
- Fixed entry point on main.py - Commented out the print reward line in step func of primaite_env.py. - Added jupyterlab==3.6.1 to pyproject.toml
This commit is contained in:
@@ -26,6 +26,7 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"gym==0.21.0",
|
"gym==0.21.0",
|
||||||
|
"jupyterlab==3.6.1",
|
||||||
"matplotlib==3.7.1",
|
"matplotlib==3.7.1",
|
||||||
"networkx==3.1",
|
"networkx==3.1",
|
||||||
"numpy==1.23.5",
|
"numpy==1.23.5",
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
"""Provides a CLI using Typer as an entry point."""
|
"""Provides a CLI using Typer as an entry point."""
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import typer
|
import typer
|
||||||
from platformdirs import PlatformDirs
|
from platformdirs import PlatformDirs
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
@@ -33,26 +33,17 @@ def reset_notebooks(overwrite: bool = True):
|
|||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def logs(last_n: int = 10):
|
def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||||
"""
|
"""
|
||||||
Print the PrimAITE log file.
|
Print the PrimAITE log file.
|
||||||
|
|
||||||
:param last_n: The number of lines to print. Default value is 10.
|
:param last_n: The number of lines to print. Default value is 10.
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
|
from primaite import LOG_PATH
|
||||||
|
|
||||||
from platformdirs import PlatformDirs
|
if os.path.isfile(LOG_PATH):
|
||||||
|
with open(LOG_PATH) as file:
|
||||||
yt_platform_dirs = PlatformDirs(appname="primaite")
|
|
||||||
|
|
||||||
if sys.platform == "win32":
|
|
||||||
log_dir = yt_platform_dirs.user_data_path / "logs"
|
|
||||||
else:
|
|
||||||
log_dir = yt_platform_dirs.user_log_path
|
|
||||||
log_path = os.path.join(log_dir, "primaite.log")
|
|
||||||
|
|
||||||
if os.path.isfile(log_path):
|
|
||||||
with open(log_path) as file:
|
|
||||||
lines = file.readlines()
|
lines = file.readlines()
|
||||||
for line in lines[-last_n:]:
|
for line in lines[-last_n:]:
|
||||||
print(re.sub(r"\n*", "", line))
|
print(re.sub(r"\n*", "", line))
|
||||||
@@ -89,6 +80,7 @@ def setup(overwrite_existing: bool = True):
|
|||||||
|
|
||||||
WARNING: All user-data will be lost.
|
WARNING: All user-data will be lost.
|
||||||
"""
|
"""
|
||||||
|
# Does this way to avoid using PrimAITE package before config is loaded
|
||||||
app_dirs = PlatformDirs(appname="primaite")
|
app_dirs = PlatformDirs(appname="primaite")
|
||||||
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
|
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
|
||||||
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
|
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
|
||||||
|
|||||||
@@ -377,7 +377,7 @@ class Primaite(Env):
|
|||||||
self.step_count,
|
self.step_count,
|
||||||
self.training_config,
|
self.training_config,
|
||||||
)
|
)
|
||||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||||
self.total_reward += reward
|
self.total_reward += reward
|
||||||
if self.step_count == self.episode_steps:
|
if self.step_count == self.episode_steps:
|
||||||
self.average_reward = self.total_reward / self.step_count
|
self.average_reward = self.total_reward / self.step_count
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ TODO: This will eventually be refactored out into a proper Session class.
|
|||||||
TODO: The passing about of session_dir and timestamp_str is temporary and
|
TODO: The passing about of session_dir and timestamp_str is temporary and
|
||||||
will be cleaned up once we move to a proper Session class.
|
will be cleaned up once we move to a proper Session class.
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -19,10 +20,10 @@ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
|||||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||||
|
|
||||||
from primaite import SESSIONS_DIR, getLogger
|
from primaite import SESSIONS_DIR, getLogger
|
||||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
from primaite.config.training_config import TrainingConfig
|
||||||
from primaite.config.training_config import TrainingConfig, main_training_config_path
|
|
||||||
from primaite.environment.primaite_env import Primaite
|
from primaite.environment.primaite_env import Primaite
|
||||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
from primaite.transactions.transactions_to_file import \
|
||||||
|
write_transaction_to_file
|
||||||
|
|
||||||
_LOGGER = getLogger(__name__)
|
_LOGGER = getLogger(__name__)
|
||||||
|
|
||||||
@@ -334,19 +335,19 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
|
|||||||
_LOGGER.debug("Finished")
|
_LOGGER.debug("Finished")
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# parser.add_argument("--tc")
|
parser.add_argument("--tc")
|
||||||
# parser.add_argument("--ldc")
|
parser.add_argument("--ldc")
|
||||||
# args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# if not args.tc:
|
if not args.tc:
|
||||||
# _LOGGER.error(
|
_LOGGER.error(
|
||||||
# "Please provide a training config file using the --tc " "argument"
|
"Please provide a training config file using the --tc " "argument"
|
||||||
# )
|
)
|
||||||
# if not args.ldc:
|
if not args.ldc:
|
||||||
# _LOGGER.error(
|
_LOGGER.error(
|
||||||
# "Please provide a lay down config file using the --ldc " "argument"
|
"Please provide a lay down config file using the --ldc " "argument"
|
||||||
# )
|
)
|
||||||
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||||
|
|
||||||
|
|
||||||
run(main_training_config_path(), data_manipulation_config_path())
|
|
||||||
|
|||||||
@@ -17,20 +17,16 @@ def start_jupyter_session():
|
|||||||
|
|
||||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||||
"""
|
"""
|
||||||
if sys.platform == "win32":
|
|
||||||
if importlib.util.find_spec("jupyter") is not None:
|
if importlib.util.find_spec("jupyter") is not None:
|
||||||
# Jupyter is installed
|
jupyter_cmd = "python3 -m jupyter lab"
|
||||||
working_dir = os.getcwd()
|
if sys.platform == "win32":
|
||||||
os.chdir(NOTEBOOKS_DIR)
|
jupyter_cmd = "jupyter lab"
|
||||||
subprocess.Popen("jupyter lab")
|
|
||||||
os.chdir(working_dir)
|
working_dir = os.getcwd()
|
||||||
else:
|
os.chdir(NOTEBOOKS_DIR)
|
||||||
# Jupyter is not installed
|
subprocess.Popen(jupyter_cmd)
|
||||||
_LOGGER.error("Cannot start jupyter lab as it is not installed")
|
os.chdir(working_dir)
|
||||||
else:
|
else:
|
||||||
msg = (
|
# Jupyter is not installed
|
||||||
"Feature currently only supported on Windows OS. For "
|
_LOGGER.error("Cannot start jupyter lab as it is not installed")
|
||||||
"Linux/MacOS users, run 'cd ~/primaite/notebooks; jupyter "
|
|
||||||
"lab' from your Python environment."
|
|
||||||
)
|
|
||||||
_LOGGER.warning(msg)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user