Fix typehint issues

This commit is contained in:
Marek Wolan
2023-07-31 17:07:56 +01:00
parent 0a079832e9
commit 9cf5bfa1b2
3 changed files with 14 additions and 13 deletions

View File

@@ -41,7 +41,7 @@ _TRAINING_CONFIG_PATH = _BENCHMARK_ROOT / "config" / "benchmark_training_config.
_LAY_DOWN_CONFIG_PATH = data_manipulation_config_path()
def get_size(size_bytes: int):
def get_size(size_bytes: int) -> str:
"""
Scale bytes to its proper format.
@@ -84,7 +84,7 @@ def _get_system_info() -> Dict:
def _build_benchmark_latex_report(
benchmark_metadata_dict: Dict, this_version_plot_path: Path, all_version_plot_path: Path
):
) -> None:
geometry_options = {"tmargin": "2.5cm", "rmargin": "2.5cm", "bmargin": "2.5cm", "lmargin": "2.5cm"}
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
@@ -186,7 +186,7 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
) -> None:
super().__init__(training_config_path, lay_down_config_path)
self.setup()
@@ -195,10 +195,11 @@ class BenchmarkPrimaiteSession(PrimaiteSession):
"""Direct access to the env for ease of testing."""
return self._agent_session._env # noqa
def __enter__(self):
def __enter__(self) -> "BenchmarkPrimaiteSession":
return self
def __exit__(self, type, value, tb):
# TODO: typehints uncertain
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
shutil.rmtree(self.session_path)
_LOGGER.debug(f"Deleted benchmark session directory: {self.session_path}")
@@ -285,7 +286,7 @@ def _build_benchmark_results_dict(start_datetime: datetime, metadata_dict: Dict)
return averaged_data
def _get_df_from_episode_av_reward_dict(data: Dict):
def _get_df_from_episode_av_reward_dict(data: Dict) -> pl.DataFrame:
data: Dict = {"episode": data.keys(), "av_reward": data.values()}
return (
@@ -360,7 +361,7 @@ def _plot_benchmark_metadata(
return fig
def _plot_all_benchmarks_combined_session_av():
def _plot_all_benchmarks_combined_session_av() -> Figure:
"""
Plot the Benchmark results for each released version of PrimAITE.
@@ -410,7 +411,7 @@ def _plot_all_benchmarks_combined_session_av():
return fig
def run():
def run() -> NotImplementedError:
"""Run the PrimAITE benchmark."""
start_datetime = datetime.now()
av_reward_per_episode_dicts = {}