diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index c58f0103..2cd44755 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -134,23 +134,15 @@ _host_primaite_config() def _get_primaite_config() -> Dict: config_path = PRIMAITE_PATHS.app_config_file_path if not config_path.exists(): + # load from package if config does not exist config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) with open(config_path, "r") as file: + # load from config primaite_config = yaml.safe_load(file) - log_level_map = { - "NOTSET": logging.NOTSET, - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARN": logging.WARN, - "WARNING": logging.WARN, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL, - } - primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] - return primaite_config + return primaite_config -_PRIMAITE_CONFIG = _get_primaite_config() +PRIMAITE_CONFIG = _get_primaite_config() class _LevelFormatter(Formatter): @@ -177,11 +169,11 @@ class _LevelFormatter(Formatter): _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { - logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], - logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], - logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], + logging.DEBUG: PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], + logging.INFO: PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], + logging.WARNING: PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.ERROR: PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], + logging.CRITICAL: PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], } ) @@ -193,10 +185,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( backupCount=9, # Max 100MB of logs encoding="utf8", ) -_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) -_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) +_STREAM_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"]) +_FILE_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"]) -_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"] +_LOG_FORMAT_STR: Final[str] = PRIMAITE_CONFIG["logging"]["logger_format"] _STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) _FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) @@ -215,6 +207,6 @@ def getLogger(name: str) -> Logger: # noqa logging config. """ logger = logging.getLogger(name) - logger.setLevel(_PRIMAITE_CONFIG["log_level"]) + logger.setLevel(PRIMAITE_CONFIG["logging"]["log_level"]) return logger diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index cde7136c..9f936249 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -3,11 +3,11 @@ from datetime import datetime from enum import IntEnum from pathlib import Path -from primaite import _PRIMAITE_ROOT +from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG __all__ = ["SIM_OUTPUT"] -from primaite.utils.cli.primaite_config_utils import get_primaite_config_dict, is_dev_mode +from primaite.utils.cli.primaite_config_utils import is_dev_mode class LogLevel(IntEnum): @@ -47,7 +47,7 @@ class _SimOutput: @property def save_pcap_logs(self) -> bool: if is_dev_mode(): - return get_primaite_config_dict().get("developer_mode").get("output_pcap_logs") + return PRIMAITE_CONFIG.get("developer_mode").get("output_pcap_logs") return self._save_pcap_logs @save_pcap_logs.setter @@ -57,7 +57,7 @@ class _SimOutput: @property def save_sys_logs(self) -> bool: if is_dev_mode(): - return get_primaite_config_dict().get("developer_mode").get("output_sys_logs") + return PRIMAITE_CONFIG.get("developer_mode").get("output_sys_logs") return self._save_sys_logs @save_sys_logs.setter @@ -67,7 +67,7 @@ class _SimOutput: @property def write_sys_log_to_terminal(self) -> bool: if is_dev_mode(): - return get_primaite_config_dict().get("developer_mode").get("output_to_terminal") + return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal") return self._write_sys_log_to_terminal @write_sys_log_to_terminal.setter @@ -77,7 +77,7 @@ class _SimOutput: @property def sys_log_level(self) -> LogLevel: if is_dev_mode(): - return LogLevel[get_primaite_config_dict().get("developer_mode").get("sys_log_level")] + return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("sys_log_level")] return self._sys_log_level @sys_log_level.setter diff --git a/src/primaite/utils/cli/dev_cli.py b/src/primaite/utils/cli/dev_cli.py index 03567785..8d426b2d 100644 --- a/src/primaite/utils/cli/dev_cli.py +++ b/src/primaite/utils/cli/dev_cli.py @@ -3,9 +3,9 @@ import typer from rich import print from typing_extensions import Annotated -from primaite import _PRIMAITE_ROOT +from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG from primaite.simulator import LogLevel -from primaite.utils.cli.primaite_config_utils import get_primaite_config_dict, is_dev_mode, update_primaite_config +from primaite.utils.cli.primaite_config_utils import is_dev_mode, update_primaite_application_config dev = typer.Typer() @@ -45,28 +45,18 @@ def show(): @dev.command() def enable(): """Enable the development mode for PrimAITE.""" - config_dict = get_primaite_config_dict() - - if config_dict is None: - return - # enable dev mode - config_dict["developer_mode"]["enabled"] = True - update_primaite_config(config_dict) + PRIMAITE_CONFIG["developer_mode"]["enabled"] = True + update_primaite_application_config() print(DEVELOPMENT_MODE_MESSAGE) @dev.command() def disable(): """Disable the development mode for PrimAITE.""" - config_dict = get_primaite_config_dict() - - if config_dict is None: - return - # disable dev mode - config_dict["developer_mode"]["enabled"] = False - update_primaite_config(config_dict) + PRIMAITE_CONFIG["developer_mode"]["enabled"] = False + update_primaite_application_config() print(PRODUCTION_MODE_MESSAGE) @@ -105,29 +95,24 @@ def config_callback( ] = None, ): """Configure the development tools and environment.""" - config_dict = get_primaite_config_dict() - - if config_dict is None: - return - if ctx.params.get("sys_log_level") is not None: - config_dict["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level") + PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level") print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}") if output_sys_logs is not None: - config_dict["developer_mode"]["output_sys_logs"] = output_sys_logs + PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs print(f"PrimAITE dev-mode config updated {output_sys_logs=}") if output_pcap_logs is not None: - config_dict["developer_mode"]["output_pcap_logs"] = output_pcap_logs + PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs print(f"PrimAITE dev-mode config updated {output_pcap_logs=}") if output_to_terminal is not None: - config_dict["developer_mode"]["output_to_terminal"] = output_to_terminal + PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] = output_to_terminal print(f"PrimAITE dev-mode config updated {output_to_terminal=}") # update application config - update_primaite_config(config_dict) + update_primaite_application_config() config_typer = typer.Typer( @@ -159,15 +144,10 @@ def path( ] = None, ): """Set the output directory for the PrimAITE system and PCAP logs.""" - config_dict = get_primaite_config_dict() - - if config_dict is None: - return - if default: - config_dict["developer_mode"]["output_dir"] = None + PRIMAITE_CONFIG["developer_mode"]["output_dir"] = None # update application config - update_primaite_config(config_dict) + update_primaite_application_config() print( f"PrimAITE dev-mode output_dir [medium_turquoise]" f"{str(_PRIMAITE_ROOT.parent.parent / 'simulation_output')}" @@ -176,7 +156,7 @@ def path( return if directory: - config_dict["developer_mode"]["output_dir"] = directory + PRIMAITE_CONFIG["developer_mode"]["output_dir"] = directory # update application config - update_primaite_config(config_dict) + update_primaite_application_config() print(f"PrimAITE dev-mode output_dir [medium_turquoise]{directory}[/medium_turquoise]") diff --git a/src/primaite/utils/cli/primaite_config_utils.py b/src/primaite/utils/cli/primaite_config_utils.py index a11c2ce3..e0f6fe56 100644 --- a/src/primaite/utils/cli/primaite_config_utils.py +++ b/src/primaite/utils/cli/primaite_config_utils.py @@ -1,37 +1,18 @@ -from pathlib import Path -from typing import Dict, Optional - import yaml -from primaite import PRIMAITE_PATHS - - -def get_primaite_config_dict(config_path: Optional[Path] = None) -> Dict: - """ - Returns a dict containing the PrimAITE application config. - - :param: config_path: takes in a path object - leave empty to use the default app config path - """ - err_msg = "PrimAITE application config could not be loaded." - - if config_path is None: - config_path = PRIMAITE_PATHS.app_config_file_path - err_msg = "PrimAITE application config was not found. Have you run `primaite setup`?" - - if config_path.exists(): - with open(config_path, "r") as file: - return yaml.safe_load(file) - else: - print(err_msg) +from primaite import PRIMAITE_CONFIG, PRIMAITE_PATHS def is_dev_mode() -> bool: """Returns True if PrimAITE is currently running in developer mode.""" - config = get_primaite_config_dict() - return config["developer_mode"]["enabled"] if config.get("developer_mode", {}).get("enabled") else False + return ( + PRIMAITE_CONFIG["developer_mode"]["enabled"] + if (PRIMAITE_CONFIG.get("developer_mode", {}).get("enabled")) + else False + ) -def update_primaite_config(config: Dict) -> None: +def update_primaite_application_config() -> None: """Update the PrimAITE application config file.""" with open(PRIMAITE_PATHS.app_config_file_path, "w") as file: - yaml.dump(config, file) + yaml.dump(PRIMAITE_CONFIG, file)