Make pytest patch with temporary session dir

This commit is contained in:
Marek Wolan
2023-11-16 15:11:03 +00:00
parent 829500a60f
commit 7545c25a46
6 changed files with 30 additions and 28 deletions

View File

@@ -29,6 +29,15 @@ class _PrimaitePaths:
def __init__(self) -> None:
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
self.user_home_path = self.generate_user_home_path()
self.user_sessions_path = self.generate_user_sessions_path()
self.user_config_path = self.generate_user_config_path()
self.user_notebooks_path = self.generate_user_notebooks_path()
self.app_home_path = self.generate_app_home_path()
self.app_config_dir_path = self.generate_app_config_dir_path()
self.app_config_file_path = self.generate_app_config_file_path()
self.app_log_dir_path = self.generate_app_log_dir_path()
self.app_log_file_path = self.generate_app_log_file_path()
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
@@ -43,55 +52,47 @@ class _PrimaitePaths:
for p in self._get_dirs_properties():
getattr(self, p)
@property
def user_home_path(self) -> Path:
def generate_user_home_path(self) -> Path:
"""The PrimAITE user home path."""
path = Path.home() / "primaite" / __version__
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_sessions_path(self) -> Path:
def generate_user_sessions_path(self) -> Path:
"""The PrimAITE user sessions path."""
path = self.user_home_path / "sessions"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_config_path(self) -> Path:
def generate_user_config_path(self) -> Path:
"""The PrimAITE user config path."""
path = self.user_home_path / "config"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_notebooks_path(self) -> Path:
def generate_user_notebooks_path(self) -> Path:
"""The PrimAITE user notebooks path."""
path = self.user_home_path / "notebooks"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_home_path(self) -> Path:
def generate_app_home_path(self) -> Path:
"""The PrimAITE app home path."""
path = self._dirs.user_data_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_dir_path(self) -> Path:
def generate_app_config_dir_path(self) -> Path:
"""The PrimAITE app config directory path."""
path = self._dirs.user_config_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_file_path(self) -> Path:
def generate_app_config_file_path(self) -> Path:
"""The PrimAITE app config file path."""
return self.app_config_dir_path / "primaite_config.yaml"
@property
def app_log_dir_path(self) -> Path:
def generate_app_log_dir_path(self) -> Path:
"""The PrimAITE app log directory path."""
if sys.platform == "win32":
path = self.app_home_path / "logs"
@@ -100,8 +101,7 @@ class _PrimaitePaths:
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_log_file_path(self) -> Path:
def generate_app_log_file_path(self) -> Path:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"

View File

@@ -2,7 +2,7 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_steps: 0
n_learn_episodes: 0
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false

View File

@@ -2,7 +2,7 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_steps: 2560
n_learn_episodes: 10
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false

View File

@@ -2,7 +2,7 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_steps: 2560
n_learn_episodes: 10
n_eval_episodes: 0
max_steps_per_episode: 128
deterministic_eval: false

View File

@@ -5,7 +5,6 @@ import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import patch
import nodeenv
import pytest
@@ -22,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
from primaite import PRIMAITE_PATHS
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
@@ -97,8 +98,9 @@ class TempPrimaiteSession(PrimaiteSession):
@pytest.fixture
def temp_primaite_session(request) -> TempPrimaiteSession:
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
"""Create a temporary PrimaiteSession object."""
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)
with monkeypatch.context() as m:
m.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)

View File

@@ -9,7 +9,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_temp_session_path(session_timestamp: datetime) -> Path:
def temp_user_sessions_path() -> Path:
"""
Get a temp directory session path the test session will output to.