mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2025-10-09 13:40:09 +03:00
add numbers to performance heatmap (#442)
This commit is contained in:
committed by
GitHub
parent
b843f33b1d
commit
6614338ecc
@@ -392,7 +392,7 @@ def create_performance_heatmap(
|
||||
|
||||
Rows : models (sorted by overall mean score, high→low)
|
||||
Cols : datasets grouped by `categories`
|
||||
Cell : 100 × raw score
|
||||
Cell : 100 × raw score (value shown inside each cell)
|
||||
"""
|
||||
if not summaries:
|
||||
logger.error("No summaries provided")
|
||||
@@ -415,12 +415,11 @@ def create_performance_heatmap(
|
||||
|
||||
# ---- plot
|
||||
fig, ax = plt.subplots(figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5)))
|
||||
|
||||
im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100)
|
||||
|
||||
# colour-bar
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.ax.set_ylabel("Score (%)", rotation=-90, va="bottom")
|
||||
cbar.ax.set_ylabel("Score (\%)", rotation=-90, va="bottom")
|
||||
|
||||
# ticks & labels
|
||||
ax.set_xticks(np.arange(len(all_datasets)))
|
||||
@@ -430,7 +429,7 @@ def create_performance_heatmap(
|
||||
|
||||
# category separators & titles
|
||||
current = 0
|
||||
label_offset = -0.25 # ↓ push labels down (was around −0.7)
|
||||
label_offset = -0.25
|
||||
for cat, ds in sorted(categories.items()):
|
||||
if not ds:
|
||||
continue
|
||||
@@ -441,7 +440,7 @@ def create_performance_heatmap(
|
||||
mid = current + len(ds) / 2 - 0.5
|
||||
ax.text(
|
||||
mid,
|
||||
label_offset, # <-- use offset
|
||||
label_offset,
|
||||
cat,
|
||||
ha="center",
|
||||
va="top",
|
||||
@@ -455,7 +454,22 @@ def create_performance_heatmap(
|
||||
ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True)
|
||||
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
|
||||
|
||||
# ax.set_title("Model Performance by Dataset", fontsize=15)
|
||||
# ---- annotate every cell with its value
|
||||
for i in range(len(models)):
|
||||
for j in range(len(all_datasets)):
|
||||
val = score_matrix[i, j]
|
||||
ax.text(
|
||||
j,
|
||||
i,
|
||||
f"{val:.1f}",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
rotation=-90, # 90° clockwise
|
||||
rotation_mode="anchor", # keep anchor point fixed
|
||||
color="white" if val >= 50 else "black",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
@@ -646,7 +660,7 @@ def create_comparison_plot(
|
||||
diff_matrix[i, j] = 100 * (cur_mean - base_mean)
|
||||
|
||||
# ---------------------------------------------------------------- plot
|
||||
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))
|
||||
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.58)))
|
||||
|
||||
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user