diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 789517f7..28245d33 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -29,6 +29,15 @@ class _PrimaitePaths: def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) + self.user_home_path = self.generate_user_home_path() + self.user_sessions_path = self.generate_user_sessions_path() + self.user_config_path = self.generate_user_config_path() + self.user_notebooks_path = self.generate_user_notebooks_path() + self.app_home_path = self.generate_app_home_path() + self.app_config_dir_path = self.generate_app_config_dir_path() + self.app_config_file_path = self.generate_app_config_file_path() + self.app_log_dir_path = self.generate_app_log_dir_path() + self.app_log_file_path = self.generate_app_log_file_path() def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() @@ -43,55 +52,47 @@ class _PrimaitePaths: for p in self._get_dirs_properties(): getattr(self, p) - @property - def user_home_path(self) -> Path: + def generate_user_home_path(self) -> Path: """The PrimAITE user home path.""" path = Path.home() / "primaite" / __version__ path.mkdir(exist_ok=True, parents=True) return path - @property - def user_sessions_path(self) -> Path: + def generate_user_sessions_path(self) -> Path: """The PrimAITE user sessions path.""" path = self.user_home_path / "sessions" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_config_path(self) -> Path: + def generate_user_config_path(self) -> Path: """The PrimAITE user config path.""" path = self.user_home_path / "config" path.mkdir(exist_ok=True, parents=True) return path - @property - def user_notebooks_path(self) -> Path: + def generate_user_notebooks_path(self) -> Path: """The PrimAITE user notebooks path.""" path = self.user_home_path / "notebooks" path.mkdir(exist_ok=True, parents=True) return path - @property - def app_home_path(self) -> Path: + def generate_app_home_path(self) -> Path: """The PrimAITE app home path.""" path = self._dirs.user_data_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_dir_path(self) -> Path: + def generate_app_config_dir_path(self) -> Path: """The PrimAITE app config directory path.""" path = self._dirs.user_config_path path.mkdir(exist_ok=True, parents=True) return path - @property - def app_config_file_path(self) -> Path: + def generate_app_config_file_path(self) -> Path: """The PrimAITE app config file path.""" return self.app_config_dir_path / "primaite_config.yaml" - @property - def app_log_dir_path(self) -> Path: + def generate_app_log_dir_path(self) -> Path: """The PrimAITE app log directory path.""" if sys.platform == "win32": path = self.app_home_path / "logs" @@ -100,8 +101,7 @@ class _PrimaitePaths: path.mkdir(exist_ok=True, parents=True) return path - @property - def app_log_file_path(self) -> Path: + def generate_app_log_file_path(self) -> Path: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 2ab7a2cc..1c9104d1 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 0 + n_learn_episodes: 0 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index dca9620f..201528eb 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 10 n_eval_episodes: 5 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5f0cfc77..1ed10212 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -2,7 +2,7 @@ training_config: rl_framework: SB3 rl_algorithm: PPO seed: 333 - n_learn_steps: 2560 + n_learn_episodes: 10 n_eval_episodes: 0 max_steps_per_episode: 128 deterministic_eval: false diff --git a/tests/conftest.py b/tests/conftest.py index 60b69a1e..fe450213 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import tempfile from datetime import datetime from pathlib import Path from typing import Any, Dict, Union -from unittest.mock import patch import nodeenv import pytest @@ -22,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.service import Service -from tests.mock_and_patch.get_session_path_mock import get_temp_session_path +from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) +from primaite import PRIMAITE_PATHS + # PrimAITE v3 stuff from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.base import Node @@ -97,8 +98,9 @@ class TempPrimaiteSession(PrimaiteSession): @pytest.fixture -def temp_primaite_session(request) -> TempPrimaiteSession: +def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: """Create a temporary PrimaiteSession object.""" - - config_path = request.param[0] - return TempPrimaiteSession.from_config(config_path=config_path) + with monkeypatch.context() as m: + m.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) + config_path = request.param[0] + return TempPrimaiteSession.from_config(config_path=config_path) diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index 16c4a274..06fe5893 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -9,7 +9,7 @@ from primaite import getLogger _LOGGER = getLogger(__name__) -def get_temp_session_path(session_timestamp: datetime) -> Path: +def temp_user_sessions_path() -> Path: """ Get a temp directory session path the test session will output to.