diff --git a/.azuredevops/pull_request_template.md b/.azuredevops/pull_request_template.md index 5ff03e18..fd28ed57 100644 --- a/.azuredevops/pull_request_template.md +++ b/.azuredevops/pull_request_template.md @@ -9,4 +9,5 @@ - [ ] I have performed **self-review** of the code - [ ] I have written **tests** for any new functionality added with this PR - [ ] I have updated the **documentation** if this PR changes or adds functionality +- [ ] I have written/updated **design docs** if this PR implements new functionality. - [ ] I have run **pre-commit** checks for code style diff --git a/.flake8 b/.flake8 index 398d14fb..6e653102 100644 --- a/.flake8 +++ b/.flake8 @@ -9,5 +9,8 @@ extend-ignore = E712 D401 F811 + ANN101 + ANN102 exclude = docs/source/* + tests/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e435bee..494ea937 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,3 +27,4 @@ repos: - id: flake8 additional_dependencies: - flake8-docstrings + - flake8-annotations diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index ead5723b..9fec5711 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -41,7 +41,7 @@ _TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config. _LAY_DOWN_CONFIG_PATH = data_manipulation_config_path() -def get_size(size_bytes: int): +def get_size(size_bytes: int) -> str: """ Scale bytes to its proper format. @@ -84,7 +84,7 @@ def _get_system_info() -> Dict: def _build_benchmark_latex_report( benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path -): +) -> None: geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"} data = benchmark_metadata_dict primaite_version = data["primaite_version"] @@ -186,7 +186,7 @@ class BenchmarkPrimaiteSession(PrimaiteSession): self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], - ): + ) -> None: super().__init__(training_config_path, lay_down_config_path) self.setup() @@ -195,10 +195,11 @@ class BenchmarkPrimaiteSession(PrimaiteSession): """Direct access to the env for ease of testing.""" return self._agent_session._env # noqa - def __enter__(self): + def __enter__(self) -> "BenchmarkPrimaiteSession": return self - def __exit__(self, type, value, tb): + # TODO: typehints uncertain + def __exit__(self, type: Any, value: Any, tb: Any) -> None: shutil.rmtree(self.session_path) _LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}") @@ -285,7 +286,7 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict) return averaged_data -def _get_df_from_episode_av_reward_dict(data: Dict): +def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame: data: Dict = {"episode": data.keys(), "av_reward": data.values()} return ( @@ -360,7 +361,7 @@ def _plot_benchmark_metadata( return fig -def _plot_all_benchmarks_combined_session_av(): +def _plot_all_benchmarks_combined_session_av() -> Figure: """ Plot the Benchmark results for each released version of PrimAITE. @@ -410,7 +411,7 @@ def _plot_all_benchmarks_combined_session_av(): return fig -def run(): +def run() -> None: """Run the PrimAITE benchmark.""" start_datetime = datetime.now() av_reward_per_episode_dicts = {} diff --git a/pyproject.toml b/pyproject.toml index 4e8250d8..4982dfd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ license-files = ["LICENSE"] dev = [ "build==0.10.0", "flake8==6.0.0", + "flake8-annotations", "furo==2023.3.27", "gputil==1.4.0", "pip-licenses==4.3.0", diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index a0f5b7fe..ad157c9c 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -24,14 +24,14 @@ class _PrimaitePaths: The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__`. """ - def __init__(self): + def __init__(self) -> None: self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__) def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() return [k for k, v in class_items if isinstance(v, property)] - def mkdirs(self): + def mkdirs(self) -> None: """ Creates all Primaite directories. @@ -102,7 +102,7 @@ class _PrimaitePaths: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" - def __repr__(self): + def __repr__(self) -> str: properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()]) return f"{self.__class__.__name__}({properties_str})" @@ -110,7 +110,7 @@ class _PrimaitePaths: PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths() -def _host_primaite_config(): +def _host_primaite_config() -> None: if not PRIMAITE_PATHS.app_config_file_path.exists(): pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 383a9b5a..be80374b 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -443,7 +443,7 @@ class AccessControlList(AbstractObservationComponent): _DATA_TYPE: type = np.int64 - def __init__(self, env: "Primaite"): + def __init__(self, env: "Primaite") -> None: """ Initialise an AccessControlList observation component.