Make pytest patch with temporary session dir
This commit is contained in:
@@ -29,6 +29,15 @@ class _PrimaitePaths:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
|
||||
self.user_home_path = self.generate_user_home_path()
|
||||
self.user_sessions_path = self.generate_user_sessions_path()
|
||||
self.user_config_path = self.generate_user_config_path()
|
||||
self.user_notebooks_path = self.generate_user_notebooks_path()
|
||||
self.app_home_path = self.generate_app_home_path()
|
||||
self.app_config_dir_path = self.generate_app_config_dir_path()
|
||||
self.app_config_file_path = self.generate_app_config_file_path()
|
||||
self.app_log_dir_path = self.generate_app_log_dir_path()
|
||||
self.app_log_file_path = self.generate_app_log_file_path()
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
@@ -43,55 +52,47 @@ class _PrimaitePaths:
|
||||
for p in self._get_dirs_properties():
|
||||
getattr(self, p)
|
||||
|
||||
@property
|
||||
def user_home_path(self) -> Path:
|
||||
def generate_user_home_path(self) -> Path:
|
||||
"""The PrimAITE user home path."""
|
||||
path = Path.home() / "primaite" / __version__
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_sessions_path(self) -> Path:
|
||||
def generate_user_sessions_path(self) -> Path:
|
||||
"""The PrimAITE user sessions path."""
|
||||
path = self.user_home_path / "sessions"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_config_path(self) -> Path:
|
||||
def generate_user_config_path(self) -> Path:
|
||||
"""The PrimAITE user config path."""
|
||||
path = self.user_home_path / "config"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_notebooks_path(self) -> Path:
|
||||
def generate_user_notebooks_path(self) -> Path:
|
||||
"""The PrimAITE user notebooks path."""
|
||||
path = self.user_home_path / "notebooks"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_home_path(self) -> Path:
|
||||
def generate_app_home_path(self) -> Path:
|
||||
"""The PrimAITE app home path."""
|
||||
path = self._dirs.user_data_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_dir_path(self) -> Path:
|
||||
def generate_app_config_dir_path(self) -> Path:
|
||||
"""The PrimAITE app config directory path."""
|
||||
path = self._dirs.user_config_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_file_path(self) -> Path:
|
||||
def generate_app_config_file_path(self) -> Path:
|
||||
"""The PrimAITE app config file path."""
|
||||
return self.app_config_dir_path / "primaite_config.yaml"
|
||||
|
||||
@property
|
||||
def app_log_dir_path(self) -> Path:
|
||||
def generate_app_log_dir_path(self) -> Path:
|
||||
"""The PrimAITE app log directory path."""
|
||||
if sys.platform == "win32":
|
||||
path = self.app_home_path / "logs"
|
||||
@@ -100,8 +101,7 @@ class _PrimaitePaths:
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_log_file_path(self) -> Path:
|
||||
def generate_app_log_file_path(self) -> Path:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_steps: 0
|
||||
n_learn_episodes: 0
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
|
||||
@@ -2,7 +2,7 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_steps: 2560
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
|
||||
@@ -2,7 +2,7 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_steps: 2560
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 0
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
|
||||
@@ -5,7 +5,6 @@ import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import nodeenv
|
||||
import pytest
|
||||
@@ -22,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
# PrimAITE v3 stuff
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
@@ -97,8 +98,9 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_primaite_session(request) -> TempPrimaiteSession:
|
||||
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
|
||||
"""Create a temporary PrimaiteSession object."""
|
||||
|
||||
config_path = request.param[0]
|
||||
return TempPrimaiteSession.from_config(config_path=config_path)
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
|
||||
config_path = request.param[0]
|
||||
return TempPrimaiteSession.from_config(config_path=config_path)
|
||||
|
||||
@@ -9,7 +9,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
def temp_user_sessions_path() -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user