comparison plot (#436)

This commit is contained in:
Zafir Stojanovski
2025-05-18 23:57:49 +02:00
committed by GitHub
parent 0cda6b1205
commit 5961a10145

View File

@@ -587,76 +587,86 @@ def create_dashboard(summaries: Dict[str, Dict[str, Any]], categories: Dict[str,
def create_comparison_plot(
summaries: Dict[str, Dict[str, Any]],
other_summaries: Dict[str, Dict[str, Any]],
model_id: str,
categories: Optional[Dict[str, List[str]]] = None,
) -> Figure:
"""Create a comparison plot between two models.
"""
Build a heat-map of per-category score differences (scaled to 100 … 100).
Args:
summaries: Dictionary of model summaries
other_summaries: Dictionary of other model summaries for comparison
model_id: Model ID to compare with
categories: Dictionary mapping categories to dataset lists
Rows : model IDs present in both `summaries` and `other_summaries`
Cols : category names (`categories`)
Value : 100 * (mean(score in summaries) mean(score in other_summaries))
Returns:
Matplotlib figure
A numeric annotation (rounded to 2 dp) is rendered in every cell.
"""
if not summaries or not other_summaries:
logger.error("No summaries provided for comparison")
return plt.figure()
current_scores, baseline_scores = {}, {}
if categories is None:
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
categories = {"all": list(all_ds)}
# models appearing in both result sets
common_models = [m for m in summaries if m in other_summaries]
if not common_models:
logger.error("No overlapping model IDs between the two result sets.")
return plt.figure()
# sort models by overall performance
overall_scores = {}
for model_name, summary in summaries.items():
if model_name == model_id:
for category, datasets in categories.items():
datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
if datasets_scores: # Avoid division by zero
current_scores[category] = np.mean(datasets_scores)
else:
current_scores[category] = 0
scores = list(summary["dataset_best_scores"].values())
overall_scores[model_name] = np.mean(scores)
models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
common_models = [m for m in models if m in common_models]
for model_name, summary in other_summaries.items():
if model_name == model_id:
for category, datasets in categories.items():
datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
if datasets_scores:
baseline_scores[category] = np.mean(datasets_scores)
else:
baseline_scores[category] = 0
category_list = sorted(categories.keys())
diff_matrix = np.zeros((len(common_models), len(category_list)))
logger.debug(f"Current scores: {current_scores}")
logger.debug(f"Baseline scores: {baseline_scores}")
# compute 100 × Δ
for i, model in enumerate(common_models):
cur_scores = summaries[model]["dataset_best_scores"]
base_scores = other_summaries[model]["dataset_best_scores"]
# Create a bar chart for comparison
fig, ax = plt.subplots(figsize=(20, 10))
categories_list = sorted(current_scores.keys())
current_values = [round(current_scores[cat] * 100, 2) for cat in categories_list]
baseline_values = [round(baseline_scores[cat] * 100, 2) for cat in categories_list]
for j, cat in enumerate(category_list):
ds = categories[cat]
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
diff_matrix[i, j] = 100 * (cur_mean - base_mean) # scale to -100 … 100
x = np.arange(len(categories_list))
width = 0.35
colors = plt.cm.tab10.colors
bars1 = ax.bar(x - width / 2, baseline_values, width, label="Baseline", color=colors[1])
bars2 = ax.bar(x + width / 2, current_values, width, label="Difficult", color=colors[0])
ax.set_ylabel("Average Score")
ax.set_title(f"Performance of {model_id} when increasing difficulty", size=15)
ax.set_xticks(x)
ax.set_xticklabels(categories_list, rotation=45, ha="right")
ax.legend()
ax.set_ylim(0, max(max(current_values), max(baseline_values)) * 1.1)
plt.tight_layout()
plt.grid(axis="y")
# ---------------------------------------------------------------- Plot
fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
# Add value labels on top of bars
for bar in bars1:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
for bar in bars2:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
plt.title(f"Performance of {model_id.split('/')[-1]} on Increasing Difficulty", size=15)
# colour-bar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
# ticks / labels
ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
ax.set_yticks(np.arange(len(common_models)), labels=common_models)
# grid for readability
ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# annotate each cell
for i in range(len(common_models)):
for j in range(len(category_list)):
value = diff_matrix[i, j]
ax.text(
j,
i,
f"{value:.2f}",
ha="center",
va="center",
color="black" if abs(value) < 50 else "white",
fontsize=8,
)
ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
plt.tight_layout()
return fig
@@ -768,18 +778,8 @@ def main():
if not other_summaries:
logger.error("No valid summaries found in comparison directory. Exiting.")
return 1
comparison_output_dir = args.output_dir / "comparison"
if not comparison_output_dir.exists():
logger.info(f"Creating comparison output directory {comparison_output_dir}")
comparison_output_dir.mkdir(parents=True, exist_ok=True)
for model_name in summaries.keys():
if model_name not in other_summaries:
logger.warning(f"Model {model_name} not found in comparison directory. Skippping...")
continue
fig = create_comparison_plot(summaries, other_summaries, model_name, categories)
save_figure(fig, comparison_output_dir, model_name, args.format, args.dpi)
fig = create_comparison_plot(summaries, other_summaries, categories)
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
else:
logger.warning(f"Unknown plot type: {plot_type}")