diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index ead5723b..8a911720 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -41,7 +41,7 @@ _TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config. _LAY_DOWN_CONFIG_PATH = data_manipulation_config_path() -def get_size(size_bytes: int): +def get_size(size_bytes: int) -> str: """ Scale bytes to its proper format. @@ -84,7 +84,7 @@ def _get_system_info() -> Dict: def _build_benchmark_latex_report( benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path -): +) -> None: geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"} data = benchmark_metadata_dict primaite_version = data["primaite_version"] @@ -186,7 +186,7 @@ class BenchmarkPrimaiteSession(PrimaiteSession): self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], - ): + ) -> None: super().__init__(training_config_path, lay_down_config_path) self.setup() @@ -195,10 +195,11 @@ class BenchmarkPrimaiteSession(PrimaiteSession): """Direct access to the env for ease of testing.""" return self._agent_session._env # noqa - def __enter__(self): + def __enter__(self) -> "BenchmarkPrimaiteSession": return self - def __exit__(self, type, value, tb): + # TODO: typehints uncertain + def __exit__(self, type: Any, value: Any, tb: Any) -> None: shutil.rmtree(self.session_path) _LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}") @@ -285,7 +286,7 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) return averaged_data -def _get_df_from_episode_av_reward_dict(data: Dict): +def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame: data: Dict = {"episode": data.keys(), "av_reward": data.values()} return ( @@ -360,7 +361,7 @@ def _plot_benchmark_metadata( return fig -def _plot_all_benchmarks_combined_session_av(): +def _plot_all_benchmarks_combined_session_av() -> Figure: """ Plot the Benchmark results for each released version of PrimAITE. @@ -410,7 +411,7 @@ def _plot_all_benchmarks_combined_session_av(): return fig -def run(): +def run() -> NotImplementedError: """Run the PrimAITE benchmark.""" start_datetime = datetime.now() av_reward_per_episode_dicts = {} diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index a0f5b7fe..ad157c9c 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -24,14 +24,14 @@ class _PrimaitePaths: The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__`. """ - def __init__(self): + def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() return [k for k, v in class_items if isinstance(v, property)] - def mkdirs(self): + def mkdirs(self) -> None: """ Creates all Primaite directories. @@ -102,7 +102,7 @@ class _PrimaitePaths: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" - def __repr__(self): + def __repr__(self) -> str: properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()]) return f"{self.__class__.__name__}({properties_str})" @@ -110,7 +110,7 @@ class _PrimaitePaths: PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths() -def _host_primaite_config(): +def _host_primaite_config() -> None: if not PRIMAITE_PATHS.app_config_file_path.exists(): pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 383a9b5a..be80374b 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -443,7 +443,7 @@ class AccessControlList(AbstractObservationComponent): _DATA_TYPE: type = np.int64 - def __init__(self, env: "Primaite"): + def __init__(self, env: "Primaite") -> None: """ Initialise an AccessControlList observation component.