#2533: fix primaite config recreated when running setup

This commit is contained in:
Czar Echavez
2024-04-30 19:15:54 +01:00
parent 73f64cd89c
commit 729f9c5064
6 changed files with 74 additions and 39 deletions

View File

@@ -123,8 +123,9 @@ PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
def _host_primaite_config() -> None:
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
if not PRIMAITE_PATHS.app_config_file_path.exists():
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_host_primaite_config()

View File

@@ -2,9 +2,12 @@
"""Provides a CLI using Typer as an entry point."""
import logging
import os
import shutil
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
import typer
import yaml
from typing_extensions import Annotated
@@ -91,7 +94,7 @@ def version() -> None:
@app.command()
def setup(overwrite_existing: bool = True) -> None:
def setup(overwrite_existing: bool = False) -> None:
"""
Perform the PrimAITE first-time setup.
@@ -104,11 +107,14 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Performing the PrimAITE first-time setup...")
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...")
PRIMAITE_PATHS.mkdirs()
_LOGGER.info("Building primaite_config.yaml...")
if overwrite_existing:
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=True)

View File

@@ -5,9 +5,8 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from primaite import _PRIMAITE_ROOT, getLogger, PRIMAITE_PATHS
from primaite import getLogger, PRIMAITE_PATHS
from primaite.simulator import LogLevel, SIM_OUTPUT
from primaite.utils.cli.primaite_config_utils import get_primaite_config_dict, is_dev_mode
_LOGGER = getLogger(__name__)
@@ -63,15 +62,15 @@ class PrimaiteIO:
time_str = timestamp.strftime("%H-%M-%S")
# check if running in dev mode
if is_dev_mode():
# if dev mode, simulation output will be the repository root or whichever path is configured
app_config = get_primaite_config_dict()
if app_config["developer_mode"]["output_dir"] is not None:
session_path = app_config["developer_mode"]["output_path"]
else:
session_path = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_str / time_str
else:
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
# if is_dev_mode():
# # if dev mode, simulation output will be the repository root or whichever path is configured
# app_config = get_primaite_config_dict()
# if app_config["developer_mode"]["output_dir"] is not None:
# session_path = app_config["developer_mode"]["output_dir"]
# else:
# session_path = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_str / time_str
# else:
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
session_path.mkdir(exist_ok=True, parents=True)
return session_path

View File

@@ -26,34 +26,63 @@ class LogLevel(IntEnum):
class _SimOutput:
_default_path = _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
def __init__(self):
self._path: Path = self._default_path
self.save_pcap_logs: bool = False
self.save_sys_logs: bool = False
self.write_sys_log_to_terminal: bool = False
self.sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
if is_dev_mode():
# if dev mode, override with the values configured via the primaite dev-mode command
dev_config = get_primaite_config_dict().get("developer_mode")
self.save_pcap_logs = dev_config["output_pcap_logs"]
self.save_sys_logs = dev_config["output_sys_logs"]
self.write_sys_log_to_terminal = dev_config["output_to_terminal"]
self._path: Path = (
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
self._save_pcap_logs: bool = False
self._save_sys_logs: bool = False
self._write_sys_log_to_terminal: bool = False
self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
@property
def path(self) -> Path:
if not is_dev_mode():
return self._path
if is_dev_mode():
dev_config = get_primaite_config_dict().get("developer_mode")
return Path(dev_config["output_dir"]) if dev_config["output_dir"] else self._default_path
return self._path
@path.setter
def path(self, new_path: Path) -> None:
self._path = new_path
self._path.mkdir(exist_ok=True, parents=True)
@property
def save_pcap_logs(self) -> bool:
if is_dev_mode():
return get_primaite_config_dict().get("developer_mode").get("output_pcap_logs")
return self._save_pcap_logs
@save_pcap_logs.setter
def save_pcap_logs(self, save_pcap_logs: bool) -> None:
self._save_pcap_logs = save_pcap_logs
@property
def save_sys_logs(self) -> bool:
if is_dev_mode():
return get_primaite_config_dict().get("developer_mode").get("output_sys_logs")
return self._save_sys_logs
@save_sys_logs.setter
def save_sys_logs(self, save_sys_logs: bool) -> None:
self._save_sys_logs = save_sys_logs
@property
def write_sys_log_to_terminal(self) -> bool:
if is_dev_mode():
return get_primaite_config_dict().get("developer_mode").get("output_to_terminal")
return self._write_sys_log_to_terminal
@write_sys_log_to_terminal.setter
def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None:
self._write_sys_log_to_terminal = write_sys_log_to_terminal
@property
def sys_log_level(self) -> LogLevel:
if is_dev_mode():
return LogLevel[get_primaite_config_dict().get("developer_mode").get("sys_log_level")]
return self._sys_log_level
@sys_log_level.setter
def sys_log_level(self, sys_log_level: LogLevel) -> None:
self._sys_log_level = sys_log_level
SIM_OUTPUT = _SimOutput()

View File

@@ -28,7 +28,7 @@ def get_primaite_config_dict(config_path: Optional[Path] = None) -> Dict:
def is_dev_mode() -> bool:
"""Returns True if PrimAITE is currently running in developer mode."""
config = get_primaite_config_dict()
return config["developer_mode"]["enabled"]
return config["developer_mode"]["enabled"] if config.get("developer_mode", {}).get("enabled") else False
def update_primaite_config(config: Dict) -> None:

View File

@@ -3,6 +3,7 @@ import shutil
import tempfile
from pathlib import Path
import pkg_resources
import pytest
from primaite import _PRIMAITE_ROOT, PRIMAITE_PATHS
@@ -19,9 +20,8 @@ def test_setup():
temp_dir = tempfile.gettempdir()
temp_config = Path(temp_dir) / "primaite_config.yaml"
shutil.copyfile(
_PRIMAITE_ROOT / "setup" / "_package_data" / "primaite_config.yaml", temp_config
) # copy the default primaite config to temp directory
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copyfile(pkg_config_path, temp_config) # copy the default primaite config to temp directory
PRIMAITE_PATHS.app_config_file_path = temp_config # use the copy for the test
yield # run test
os.remove(temp_config) # clean up temp file