Mirror of https://github.com/open-thought/reasoning-gym.git (synced 2025-10-09 13:40:09 +03:00)

comparison plot (#436)

commit 5961a10145 (parent 0cda6b1205), committed via GitHub
@@ -587,76 +587,86 @@ def create_dashboard(summaries: Dict[str, Dict[str, Any]], categories: Dict[str,
 def create_comparison_plot(
     summaries: Dict[str, Dict[str, Any]],
     other_summaries: Dict[str, Dict[str, Any]],
-    model_id: str,
     categories: Optional[Dict[str, List[str]]] = None,
 ) -> Figure:
-    """Create a comparison plot between two models.
+    """
+    Build a heat-map of per-category score differences (scaled to –100 … 100).
 
-    Args:
-        summaries: Dictionary of model summaries
-        other_summaries: Dictionary of other model summaries for comparison
-        model_id: Model ID to compare with
-        categories: Dictionary mapping categories to dataset lists
+    Rows  : model IDs present in both `summaries` and `other_summaries`
+    Cols  : category names (`categories`)
+    Value : 100 * (mean(score in summaries) − mean(score in other_summaries))
 
-    Returns:
-        Matplotlib figure
+    A numeric annotation (rounded to 2 dp) is rendered in every cell.
     """
     if not summaries or not other_summaries:
         logger.error("No summaries provided for comparison")
         return plt.figure()
 
-    current_scores, baseline_scores = {}, {}
-    for model_name, summary in summaries.items():
-        if model_name == model_id:
-            for category, datasets in categories.items():
-                datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
-                if datasets_scores:  # Avoid division by zero
-                    current_scores[category] = np.mean(datasets_scores)
-                else:
-                    current_scores[category] = 0
+    if categories is None:
+        all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
+        categories = {"all": list(all_ds)}
 
-    for model_name, summary in other_summaries.items():
-        if model_name == model_id:
-            for category, datasets in categories.items():
-                datasets_scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
-                if datasets_scores:
-                    baseline_scores[category] = np.mean(datasets_scores)
-                else:
-                    baseline_scores[category] = 0
+    # models appearing in both result sets
+    common_models = [m for m in summaries if m in other_summaries]
+    if not common_models:
+        logger.error("No overlapping model IDs between the two result sets.")
+        return plt.figure()
 
-    logger.debug(f"Current scores: {current_scores}")
-    logger.debug(f"Baseline scores: {baseline_scores}")
+    # sort models by overall performance
+    overall_scores = {}
+    for model_name, summary in summaries.items():
+        scores = list(summary["dataset_best_scores"].values())
+        overall_scores[model_name] = np.mean(scores)
+    models = [item[0] for item in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)]
+    common_models = [m for m in models if m in common_models]
 
-    # Create a bar chart for comparison
-    fig, ax = plt.subplots(figsize=(20, 10))
-    categories_list = sorted(current_scores.keys())
-    current_values = [round(current_scores[cat] * 100, 2) for cat in categories_list]
-    baseline_values = [round(baseline_scores[cat] * 100, 2) for cat in categories_list]
+    category_list = sorted(categories.keys())
+    diff_matrix = np.zeros((len(common_models), len(category_list)))
 
-    x = np.arange(len(categories_list))
-    width = 0.35
-    colors = plt.cm.tab10.colors
-    bars1 = ax.bar(x - width / 2, baseline_values, width, label="Baseline", color=colors[1])
-    bars2 = ax.bar(x + width / 2, current_values, width, label="Difficult", color=colors[0])
-    ax.set_ylabel("Average Score")
-    ax.set_title(f"Performance of {model_id} when increasing difficulty", size=15)
-    ax.set_xticks(x)
-    ax.set_xticklabels(categories_list, rotation=45, ha="right")
-    ax.legend()
-    ax.set_ylim(0, max(max(current_values), max(baseline_values)) * 1.1)
-    plt.tight_layout()
-    plt.grid(axis="y")
+    # compute 100 × Δ
+    for i, model in enumerate(common_models):
+        cur_scores = summaries[model]["dataset_best_scores"]
+        base_scores = other_summaries[model]["dataset_best_scores"]
 
-    # Add value labels on top of bars
-    for bar in bars1:
-        height = bar.get_height()
-        ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
-    for bar in bars2:
-        height = bar.get_height()
-        ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2f}", ha="center", va="bottom")
+        for j, cat in enumerate(category_list):
+            ds = categories[cat]
+            cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
+            base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
+            diff_matrix[i, j] = 100 * (cur_mean - base_mean)  # scale to -100 … 100
 
-    plt.title(f"Performance of {model_id.split('/')[-1]} on Increasing Difficulty", size=15)
+    # ---------------------------------------------------------------- Plot
+    fig, ax = plt.subplots(figsize=(max(8, len(category_list) * 1.2), max(6, len(common_models) * 0.5)))
+
+    im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
+
+    # colour-bar
+    cbar = fig.colorbar(im, ax=ax)
+    cbar.ax.set_ylabel("Δ score (percentage-points)", rotation=-90, va="bottom")
+
+    # ticks / labels
+    ax.set_xticks(np.arange(len(category_list)), labels=category_list, rotation=45, ha="right")
+    ax.set_yticks(np.arange(len(common_models)), labels=common_models)
+
+    # grid for readability
+    ax.set_xticks(np.arange(-0.5, len(category_list), 1), minor=True)
+    ax.set_yticks(np.arange(-0.5, len(common_models), 1), minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
+
+    # annotate each cell
+    for i in range(len(common_models)):
+        for j in range(len(category_list)):
+            value = diff_matrix[i, j]
+            ax.text(
+                j,
+                i,
+                f"{value:.2f}",
+                ha="center",
+                va="center",
+                color="black" if abs(value) < 50 else "white",
+                fontsize=8,
+            )
+
+    ax.set_title("Per-Category Performance Δ (hard - easy)", fontsize=14)
     plt.tight_layout()
     return fig
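For reference, a minimal sketch of the per-cell value the new heat-map computes
(100 * (mean current score − mean baseline score) per category). The model ID,
dataset names, and category below are illustrative placeholders, not values from
the repo:

    import numpy as np

    # toy result dicts in the same shape as the loaded summaries
    summaries = {"model-a": {"dataset_best_scores": {"ds1": 0.90, "ds2": 0.70}}}
    other_summaries = {"model-a": {"dataset_best_scores": {"ds1": 0.60, "ds2": 0.50}}}
    categories = {"algebra": ["ds1", "ds2"]}

    cur = summaries["model-a"]["dataset_best_scores"]
    base = other_summaries["model-a"]["dataset_best_scores"]
    ds = categories["algebra"]

    cur_mean = np.mean([cur.get(d, 0.0) for d in ds])    # (0.90 + 0.70) / 2 = 0.80
    base_mean = np.mean([base.get(d, 0.0) for d in ds])  # (0.60 + 0.50) / 2 = 0.55
    print(f"{100 * (cur_mean - base_mean):.2f}")         # 25.00 percentage points
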
||||
@@ -768,18 +778,8 @@ def main():
         if not other_summaries:
             logger.error("No valid summaries found in comparison directory. Exiting.")
             return 1
-
-        comparison_output_dir = args.output_dir / "comparison"
-        if not comparison_output_dir.exists():
-            logger.info(f"Creating comparison output directory {comparison_output_dir}")
-            comparison_output_dir.mkdir(parents=True, exist_ok=True)
-
-        for model_name in summaries.keys():
-            if model_name not in other_summaries:
-                logger.warning(f"Model {model_name} not found in comparison directory. Skippping...")
-                continue
-            fig = create_comparison_plot(summaries, other_summaries, model_name, categories)
-            save_figure(fig, comparison_output_dir, model_name, args.format, args.dpi)
+        fig = create_comparison_plot(summaries, other_summaries, categories)
+        save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
 
     else:
         logger.warning(f"Unknown plot type: {plot_type}")
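A hedged usage sketch of the updated call path: one figure now covers every model
present in both result sets, instead of one bar chart per model. `save_figure` is
the repo's own helper; the standalone save below uses plain Matplotlib instead,
and the result variables are assumed to be loaded as elsewhere in main():

    # assumes summaries / other_summaries were loaded from the two results dirs
    fig = create_comparison_plot(summaries, other_summaries, categories)
    fig.savefig("model_category_delta_heatmap.png", dpi=300, bbox_inches="tight")
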