Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2025-10-09 13:40:09 +03:00

fix(eval): comparison plot (#441)

* heatmap
* filter comparison plots
* latex style
* curriculum heatmap
* pre-commit
* update figsize
* large y-ticks
* larger font
* thinner
* include 50

Committed via GitHub. Parent: f51769927e. Commit: b843f33b1d
@@ -27,6 +27,7 @@ import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 import matplotlib
+import matplotlib.colors as mcolors
 import matplotlib.pyplot as plt
 import numpy as np
@@ -42,6 +43,22 @@ logging.basicConfig(
 logger = logging.getLogger("visualize_results")


+plt.rcParams.update(
+    {
+        "text.usetex": True,
+        "font.family": "serif",
+        "font.serif": ["Computer Modern Roman"],
+        "text.latex.preamble": r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
+        "axes.labelsize": 20,
+        "font.size": 20,
+        "legend.fontsize": 14,
+        "xtick.labelsize": 14,
+        "ytick.labelsize": 14,
+        "axes.titlesize": 22,
+    }
+)
+
+
 def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
     """Load all summary.json files from subdirectories.
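Note on the new styling block: setting `text.usetex: True` hands all text rendering to a system LaTeX installation, so the script now fails at draw time on machines without TeX. A minimal sketch of a guarded variant (the `shutil.which` check and the mathtext fallback are my suggestion, not part of this commit):

    import shutil

    import matplotlib.pyplot as plt

    # Only enable TeX rendering when a latex binary is on PATH; otherwise
    # fall back to matplotlib's built-in mathtext, which needs no TeX install.
    if shutil.which("latex"):
        plt.rcParams.update({"text.usetex": True, "font.family": "serif"})
    else:
        plt.rcParams.update({"text.usetex": False, "mathtext.fontset": "cm"})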
@@ -583,13 +600,14 @@ def create_comparison_plot(
     summaries: Dict[str, Dict[str, Any]],
     other_summaries: Dict[str, Dict[str, Any]],
     categories: Optional[Dict[str, List[str]]] = None,
+    compare_model_ids: Optional[List[str]] = None,
 ) -> Figure:
     """
     Build a heat-map of per-category score differences (scaled to -100 … 100).

-    Rows  : model IDs present in both `summaries` and `other_summaries`
-    Cols  : category names (`categories`)
-    Value : 100 * (mean(score in summaries) − mean(score in other_summaries))
+    Rows  : category names (`categories`)
+    Cols  : model IDs present in both `summaries` and `other_summaries`
+    Value : 100 × (mean(score in summaries) − mean(score in other_summaries))

     A numeric annotation (rounded to 2 dp) is rendered in every cell.
     """
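For intuition, here is a toy computation of one cell of the matrix, with made-up scores (the model and dataset names are illustrative only):

    import numpy as np

    # Toy inputs mirroring the summary structure: dataset -> best score in [0, 1].
    cur_scores = {"arith_easy": 0.90, "arith_mod": 0.70}   # from `summaries`
    base_scores = {"arith_easy": 0.60, "arith_mod": 0.50}  # from `other_summaries`
    ds = ["arith_easy", "arith_mod"]  # datasets making up one category

    cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds])    # 0.80
    base_mean = np.mean([base_scores.get(d, 0.0) for d in ds])  # 0.55
    delta = 100 * (cur_mean - base_mean)                        # 25.0 percentage points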
@@ -601,55 +619,53 @@ def create_comparison_plot(
         all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
         categories = {"all": list(all_ds)}

-    # models appearing in both result sets
+    # models present in both result sets
     common_models = [m for m in summaries if m in other_summaries]
     if not common_models:
         logger.error("No overlapping model IDs between the two result sets.")
         return plt.figure()

     # sort models by overall performance
-    overall_scores = {}
-    for model_name, summary in summaries.items():
-        scores = list(summary["dataset_best_scores"].values())
-        overall_scores[model_name] = np.mean(scores)
-    models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
-    common_models = [m for m in models if m in common_models]
+    overall_scores = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
+    models = [m for m, _ in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True) if m in common_models]
+    if compare_model_ids:
+        models = [m for m in models if m in compare_model_ids]

     category_list = sorted(categories.keys())
-    diff_matrix = np.zeros((len(common_models), len(category_list)))
+    # ---------- note the transposed shape (categories × models)
+    diff_matrix = np.zeros((len(category_list), len(models)))

+    # compute 100 × Δ
-    for i, model in enumerate(common_models):
-        cur_scores = summaries[model]["dataset_best_scores"]
-        base_scores = other_summaries[model]["dataset_best_scores"]
-
-        for j, cat in enumerate(category_list):
-            ds = categories[cat]
+    for i, cat in enumerate(category_list):
+        ds = categories[cat]
+        for j, model in enumerate(models):
+            cur_scores = summaries[model]["dataset_best_scores"]
+            base_scores = other_summaries[model]["dataset_best_scores"]
             cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
             base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
-            diff_matrix[i, j] = 100 * (cur_mean - base_mean)  # scale to -100 … 100
+            diff_matrix[i, j] = 100 * (cur_mean - base_mean)

-    # ---------------------------------------------------------------- Plot
-    fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
+    # ---------------------------------------------------------------- plot
+    fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))

     im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)

     # colour-bar
     cbar = fig.colorbar(im, ax=ax)
-    cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
+    cbar.ax.set_ylabel("$\Delta$ score (\%)", rotation=-90, va="bottom", fontweight="bold")

     # ticks / labels
-    ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
-    ax.set_yticks(np.arange(len(common_models)), labels=common_models)
+    ax.set_xticks(np.arange(len(models)), labels=models, rotation=45, ha="right")
+    ax.set_yticks(np.arange(len(category_list)), labels=category_list)

     # grid for readability
-    ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
-    ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
+    ax.set_xticks(np.arange(-0.5, len(models), 1), minor=True)
+    ax.set_yticks(np.arange(-0.5, len(category_list), 1), minor=True)
     ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)

     # annotate each cell
-    for i in range(len(common_models)):
-        for j in range(len(category_list)):
+    for i in range(len(category_list)):
+        for j in range(len(models)):
             value = diff_matrix[i, j]
             ax.text(
                 j,
@@ -658,10 +674,10 @@ def create_comparison_plot(
                 ha="center",
                 va="center",
                 color="black" if abs(value) < 50 else "white",
-                fontsize=8,
+                fontsize=12,
             )

-    ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
+    # ax.set_title("Per-Category Performance $\Delta$ (hard − easy)", fontweight="bold")
     plt.tight_layout()
     return fig
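The transposition above means matrix rows now map to the y-axis (categories) and columns to the x-axis (models), since imshow draws diff[i, j] at (x=j, y=i). A minimal, self-contained sketch of that orientation, with placeholder labels of my own:

    import matplotlib.pyplot as plt
    import numpy as np

    categories = ["algebra", "logic"]           # y-axis (illustrative names)
    models = ["model-a", "model-b", "model-c"]  # x-axis (illustrative names)
    # shape: (len(categories), len(models)); diff[i, j] lands at (x=j, y=i)
    diff = np.array([[25.0, -10.0, 0.0], [5.0, 40.0, -30.0]])

    fig, ax = plt.subplots()
    ax.imshow(diff, cmap="coolwarm", vmin=-100, vmax=100)
    ax.set_xticks(np.arange(len(models)), labels=models)
    ax.set_yticks(np.arange(len(categories)), labels=categories)
    fig.savefig("orientation_check.png")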
@@ -702,6 +718,7 @@ def main():
         "--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
     )
     parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
+    parser.add_argument("--compare-model-ids", help="Comma-separated list of model IDs to compare", default=None)
     parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
     parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
     parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
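The new flag takes a plain comma-separated string and is split later in main(). A small sketch of how it parses (the model IDs in the example invocation are illustrative, not from this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--compare-model-ids", default=None)

    # e.g. visualize_results.py --compare-results-dir easy/ --compare-model-ids gpt-4o,claude-3.5
    args = parser.parse_args(["--compare-model-ids", "gpt-4o,claude-3.5"])
    compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
    print(compare_model_ids)  # ['gpt-4o', 'claude-3.5']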
@@ -773,7 +790,8 @@ def main():
         if not other_summaries:
             logger.error("No valid summaries found in comparison directory. Exiting.")
             return 1
-        fig = create_comparison_plot(summaries, other_summaries, categories)
+        compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
+        fig = create_comparison_plot(summaries, other_summaries, categories, compare_model_ids)
         save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)

     else:
150 changed lines in notebooks/plot_curriculum.ipynb (new file). File diff suppressed because one or more lines are too long.