Merge remote-tracking branch 'origin/release/3.0.0' into merge-3.0.0-to-dev
@@ -1,4 +1,5 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -13,15 +14,14 @@ from primaite.config.load import data_manipulation_config_path
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
|
||||
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
_MAJOR_V = primaite.__version__.split(".")[0]
|
||||
|
||||
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
|
||||
# Clear and recreate the output directory
|
||||
if _OUTPUT_ROOT.exists():
|
||||
shutil.rmtree(_OUTPUT_ROOT)
|
||||
_OUTPUT_ROOT.mkdir()
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results" / f"v{_MAJOR_V}"
|
||||
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT / f"v{primaite.__version__}"
|
||||
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT / "session_metadata"
|
||||
|
||||
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class BenchmarkSession:
|
||||
@@ -51,9 +51,6 @@ class BenchmarkSession:
|
||||
end_time: datetime
|
||||
"""End time for the session."""
|
||||
|
||||
session_metadata: Dict
|
||||
"""Dict containing the metadata for the session - used to generate benchmark report."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gym_env: BenchmarkPrimaiteGymEnv,
|
||||
@@ -182,8 +179,14 @@ def run(
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
session.train()
|
||||
session_metadata_dict[i] = session.session_metadata
|
||||
|
||||
# Dump the session metadata so that we're not holding it in memory as it's large
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
|
||||
json.dump(session.session_metadata, file, indent=4)
|
||||
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
|
||||
session_metadata_dict[i] = json.load(file)
|
||||
# generate report
|
||||
build_benchmark_latex_report(
|
||||
benchmark_start_time=benchmark_start_time,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -124,7 +123,7 @@ def _plot_benchmark_metadata(
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
@@ -140,7 +139,8 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
title = "PrimAITE Versions Learning Benchmark"
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
@@ -172,7 +172,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
@@ -188,8 +188,6 @@ def build_benchmark_latex_report(
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = results_root_path / v_str
|
||||
if version_result_dir.exists():
|
||||
shutil.rmtree(version_result_dir)
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# load the config file as dict
|
||||
@@ -200,7 +198,7 @@ def build_benchmark_latex_report(
|
||||
benchmark_metadata_dict = _build_benchmark_results_dict(
|
||||
start_datetime=benchmark_start_time, metadata_dict=session_metadata, config=cfg_data
|
||||
)
|
||||
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata_dict, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
@@ -241,9 +239,9 @@ def build_benchmark_latex_report(
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
|
||||
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
|
||||
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
|
||||
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
@@ -294,14 +292,14 @@ def build_benchmark_latex_report(
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
|
||||
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
|
||||
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
||||
|
Before Width: | Height: | Size: 90 KiB |
BIN
benchmark/results/v2/PrimAITE Versions Learning Benchmark.png
Normal file
|
After Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 225 KiB After Width: | Height: | Size: 225 KiB |
|
Before Width: | Height: | Size: 296 KiB |
BIN
benchmark/results/v3/PrimAITE Versions Learning Benchmark.png
Normal file
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 295 KiB |