fix(eval): comparison plot (#441)

* heatmap

* filter comparison plots

* latex style

* curriculum heatmap

* pre-commit

* update figsize

* large y-ticks

* larger font

* thinner

* include 50
This commit is contained in:
Zafir Stojanovski
2025-05-29 12:31:07 +02:00
committed by GitHub
parent f51769927e
commit b843f33b1d
2 changed files with 198 additions and 30 deletions

View File

@@ -27,6 +27,7 @@ import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
@@ -42,6 +43,22 @@ logging.basicConfig(
logger = logging.getLogger("visualize_results")
plt.rcParams.update(
{
"text.usetex": True,
"font.family": "serif",
"font.serif": ["Computer Modern Roman"],
"text.latex.preamble": r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
"axes.labelsize": 20,
"font.size": 20,
"legend.fontsize": 14,
"xtick.labelsize": 14,
"ytick.labelsize": 14,
"axes.titlesize": 22,
}
)
def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
"""Load all summary.json files from subdirectories.
@@ -583,13 +600,14 @@ def create_comparison_plot(
summaries: Dict[str, Dict[str, Any]],
other_summaries: Dict[str, Dict[str, Any]],
categories: Optional[Dict[str, List[str]]] = None,
compare_model_ids: Optional[List[str]] = None,
) -> Figure:
"""
Build a heat-map of per-category score differences (scaled to −100 … 100).
Rows : model IDs present in both `summaries` and `other_summaries`
Cols : category names (`categories`)
Value : 100 * (mean(score in summaries) − mean(score in other_summaries))
Rows : category names (`categories`)
Cols : model IDs present in both `summaries` and `other_summaries`
Value : 100 × (mean(score in summaries) − mean(score in other_summaries))
A numeric annotation (rounded to 2 dp) is rendered in every cell.
"""
@@ -601,55 +619,53 @@ def create_comparison_plot(
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
categories = {"all": list(all_ds)}
# models appearing in both result sets
# models present in both result sets
common_models = [m for m in summaries if m in other_summaries]
if not common_models:
logger.error("No overlapping model IDs between the two result sets.")
return plt.figure()
# sort models by overall performance
overall_scores = {}
for model_name, summary in summaries.items():
scores = list(summary["dataset_best_scores"].values())
overall_scores[model_name] = np.mean(scores)
models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
common_models = [m for m in models if m in common_models]
overall_scores = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
models = [m for m, _ in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True) if m in common_models]
if compare_model_ids:
models = [m for m in models if m in compare_model_ids]
category_list = sorted(categories.keys())
diff_matrix = np.zeros((len(common_models), len(category_list)))
# ---------- note the transposed shape (categories × models)
diff_matrix = np.zeros((len(category_list), len(models)))
# compute 100 × Δ
for i, model in enumerate(common_models):
cur_scores = summaries[model]["dataset_best_scores"]
base_scores = other_summaries[model]["dataset_best_scores"]
for j, cat in enumerate(category_list):
ds = categories[cat]
for i, cat in enumerate(category_list):
ds = categories[cat]
for j, model in enumerate(models):
cur_scores = summaries[model]["dataset_best_scores"]
base_scores = other_summaries[model]["dataset_best_scores"]
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
diff_matrix[i, j] = 100 * (cur_mean - base_mean) # scale to -100 … 100
diff_matrix[i, j] = 100 * (cur_mean - base_mean)
# ---------------------------------------------------------------- Plot
fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
# ---------------------------------------------------------------- plot
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.5)))
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
# colour-bar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
cbar.ax.set_ylabel("$\Delta$ score (\%)", rotation=-90, va="bottom", fontweight="bold")
# ticks / labels
ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
ax.set_yticks(np.arange(len(common_models)), labels=common_models)
ax.set_xticks(np.arange(len(models)), labels=models, rotation=45, ha="right")
ax.set_yticks(np.arange(len(category_list)), labels=category_list)
# grid for readability
ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
ax.set_xticks(np.arange(-0.5, len(models), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(category_list), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# annotate each cell
for i in range(len(common_models)):
for j in range(len(category_list)):
for i in range(len(category_list)):
for j in range(len(models)):
value = diff_matrix[i, j]
ax.text(
j,
@@ -658,10 +674,10 @@ def create_comparison_plot(
ha="center",
va="center",
color="black" if abs(value) < 50 else "white",
fontsize=8,
fontsize=12,
)
ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
# ax.set_title("Per-Category Performance $\Delta$ (hard − easy)", fontweight="bold")
plt.tight_layout()
return fig
@@ -702,6 +718,7 @@ def main():
"--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
)
parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
parser.add_argument("--compare-model-ids", help="Comma-separated list of model IDs to compare", default=None)
parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
@@ -773,7 +790,8 @@ def main():
if not other_summaries:
logger.error("No valid summaries found in comparison directory. Exiting.")
return 1
fig = create_comparison_plot(summaries, other_summaries, categories)
compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
fig = create_comparison_plot(summaries, other_summaries, categories, compare_model_ids)
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
else:

File diff suppressed because one or more lines are too long