Merged PR 151: Check type hints with pre-commit
## Summary Added `flake8-annotations` to the pre-commit hooks. This ensures that we all write type hints for all new code. There's also a minor unrelated addition to the pre-commit template. ## Test process I tried adding a function with a parameter but no typehint. Git did not allow me to commit this. ## Checklist - [x] This PR is linked to a **work item** - [x] I have performed **self-review** of the code - [x] I have written **tests** for any new functionality added with this PR - [x] I have updated the **documentation** if this PR changes or adds functionality - [x] I have run **pre-commit** checks for code style Related work items: #1721
This commit is contained in:
@@ -9,4 +9,5 @@
|
||||
- [ ] I have performed **self-review** of the code
|
||||
- [ ] I have written **tests** for any new functionality added with this PR
|
||||
- [ ] I have updated the **documentation** if this PR changes or adds functionality
|
||||
- [ ] I have written/updated **design docs** if this PR implements new functionality.
|
||||
- [ ] I have run **pre-commit** checks for code style
|
||||
|
||||
3
.flake8
3
.flake8
@@ -9,5 +9,8 @@ extend-ignore =
|
||||
E712
|
||||
D401
|
||||
F811
|
||||
ANN101
|
||||
ANN102
|
||||
exclude =
|
||||
docs/source/*
|
||||
tests/*
|
||||
|
||||
@@ -27,3 +27,4 @@ repos:
|
||||
- id: flake8
|
||||
additional_dependencies:
|
||||
- flake8-docstrings
|
||||
- flake8-annotations
|
||||
|
||||
@@ -41,7 +41,7 @@ _TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.
|
||||
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
|
||||
|
||||
|
||||
def get_size(size_bytes: int):
|
||||
def get_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Scale bytes to its proper format.
|
||||
|
||||
@@ -84,7 +84,7 @@ def _get_system_info() -> Dict:
|
||||
|
||||
def _build_benchmark_latex_report(
|
||||
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
|
||||
):
|
||||
) -> None:
|
||||
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
|
||||
data = benchmark_metadata_dict
|
||||
primaite_version = data["primaite_version"]
|
||||
@@ -186,7 +186,7 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self.setup()
|
||||
|
||||
@@ -195,10 +195,11 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
|
||||
"""Direct access to the env for ease of testing."""
|
||||
return self._agent_session._env # noqa
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "BenchmarkPrimaiteSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
# TODO: typehints uncertain
|
||||
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
|
||||
shutil.rmtree(self.session_path)
|
||||
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
|
||||
|
||||
@@ -285,7 +286,7 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict)
|
||||
return averaged_data
|
||||
|
||||
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict):
|
||||
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
|
||||
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
|
||||
|
||||
return (
|
||||
@@ -360,7 +361,7 @@ def _plot_benchmark_metadata(
|
||||
return fig
|
||||
|
||||
|
||||
def _plot_all_benchmarks_combined_session_av():
|
||||
def _plot_all_benchmarks_combined_session_av() -> Figure:
|
||||
"""
|
||||
Plot the Benchmark results for each released version of PrimAITE.
|
||||
|
||||
@@ -410,7 +411,7 @@ def _plot_all_benchmarks_combined_session_av():
|
||||
return fig
|
||||
|
||||
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Run the PrimAITE benchmark."""
|
||||
start_datetime = datetime.now()
|
||||
av_reward_per_episode_dicts = {}
|
||||
|
||||
@@ -54,6 +54,7 @@ license-files = ["LICENSE"]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
"flake8==6.0.0",
|
||||
"flake8-annotations",
|
||||
"furo==2023.3.27",
|
||||
"gputil==1.4.0",
|
||||
"pip-licenses==4.3.0",
|
||||
|
||||
@@ -24,14 +24,14 @@ class _PrimaitePaths:
|
||||
The PlatformDirs appname is 'primaite' and the version is ``primaite.__version__`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
return [k for k, v in class_items if isinstance(v, property)]
|
||||
|
||||
def mkdirs(self):
|
||||
def mkdirs(self) -> None:
|
||||
"""
|
||||
Creates all Primaite directories.
|
||||
|
||||
@@ -102,7 +102,7 @@ class _PrimaitePaths:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()])
|
||||
return f"{self.__class__.__name__}({properties_str})"
|
||||
|
||||
@@ -110,7 +110,7 @@ class _PrimaitePaths:
|
||||
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
|
||||
|
||||
|
||||
def _host_primaite_config():
|
||||
def _host_primaite_config() -> None:
|
||||
if not PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
|
||||
|
||||
@@ -443,7 +443,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user