Merge remote-tracking branch 'origin/dev' into feature/2735-usermanager-fixes

This commit is contained in:
Marek Wolan
2024-07-31 15:42:54 +01:00
3 changed files with 4 additions and 12 deletions

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_latex_report
from report import build_benchmark_md_report
from stable_baselines3 import PPO
import primaite
@@ -188,7 +188,7 @@ def run(
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
session_metadata_dict[i] = json.load(file)
# generate report
build_benchmark_latex_report(
build_benchmark_md_report(
benchmark_start_time=benchmark_start_time,
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),

View File

@@ -234,10 +234,7 @@ def _plot_av_s_per_100_steps_10_nodes(
"""
major_v = primaite.__version__.split(".")[0]
title = f"Performance of Minor and Bugfix Releases for Major Version {major_v}"
subtitle = (
f"Average Training Time per 100 Steps on 10 Nodes "
f"(target: <= {PLOT_CONFIG['av_s_per_100_steps_10_nodes_benchmark_threshold']} seconds)"
)
subtitle = "Average Training Time per 100 Steps on 10 Nodes "
title = f"{title} <br><sub>{subtitle}</sub>"
layout = go.Layout(
@@ -250,10 +247,6 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
av_s_per_100_steps_10_nodes_benchmark_threshold = PLOT_CONFIG["av_s_per_100_steps_10_nodes_benchmark_threshold"]
# Calculate the appropriate maximum y-axis value
max_y_axis_value = max(max(times), av_s_per_100_steps_10_nodes_benchmark_threshold) + 1
fig.add_trace(
go.Bar(
@@ -267,7 +260,6 @@ def _plot_av_s_per_100_steps_10_nodes(
fig.update_layout(
xaxis_title="PrimAITE Version",
yaxis_title="Avg Time per 100 Steps on 10 Nodes (seconds)",
yaxis=dict(range=[0, max_y_axis_value]),
title=title,
)

View File

@@ -52,7 +52,7 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
"ray[rllib] == 2.32.0, < 3",
"ray[rllib] >= 2.20.0, <2.33",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",