add numbers to performance heatmap (#442)

This commit is contained in:
Zafir Stojanovski
2025-05-30 18:39:13 +02:00
committed by GitHub
parent b843f33b1d
commit 6614338ecc

View File

@@ -392,7 +392,7 @@ def create_performance_heatmap(
Rows : models (sorted by overall mean score, high→low)
Cols : datasets grouped by `categories`
Cell : 100 × raw score
Cell : 100 × raw score (value shown inside each cell)
"""
if not summaries:
logger.error("No summaries provided")
@@ -415,12 +415,11 @@ def create_performance_heatmap(
# ---- plot
fig, ax = plt.subplots(figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5)))
im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100)
# colour-bar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Score (%)", rotation=-90, va="bottom")
cbar.ax.set_ylabel("Score (\%)", rotation=-90, va="bottom")
# ticks & labels
ax.set_xticks(np.arange(len(all_datasets)))
@@ -430,7 +429,7 @@ def create_performance_heatmap(
# category separators & titles
current = 0
label_offset = -0.25 # ↓ push labels down (was around 0.7)
label_offset = -0.25
for cat, ds in sorted(categories.items()):
if not ds:
continue
@@ -441,7 +440,7 @@ def create_performance_heatmap(
mid = current + len(ds) / 2 - 0.5
ax.text(
mid,
label_offset, # <-- use offset
label_offset,
cat,
ha="center",
va="top",
@@ -455,7 +454,22 @@ def create_performance_heatmap(
ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# ax.set_title("Model Performance by Dataset", fontsize=15)
# ---- annotate every cell with its value
for i in range(len(models)):
for j in range(len(all_datasets)):
val = score_matrix[i, j]
ax.text(
j,
i,
f"{val:.1f}",
ha="center",
va="center",
fontsize=7,
rotation=-90, # 90° clockwise
rotation_mode="anchor", # keep anchor point fixed
color="white" if val >= 50 else "black",
)
plt.tight_layout()
return fig
@@ -646,7 +660,7 @@ def create_comparison_plot(
diff_matrix[i, j] = 100 * (cur_mean - base_mean)
# ---------------------------------------------------------------- plot
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.58)))
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)