Files
reasoning-gym/eval/visualize_results.py
2025-05-30 18:39:13 +02:00

837 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
Visualization script for reasoning gym evaluation results.
This script generates visualizations from evaluation results stored in summary.json files.
Usage:
python visualize_results.py --results-dir results/ [options]
Options:
--output-dir DIR Directory to save visualizations (default: visualizations)
--plots PLOTS Comma-separated list of plots to generate (default: all)
Available: radar,bar,violin,heatmap,dashboard,distribution,top_datasets
--top-n N Number of datasets to show in top datasets plot (default: 15)
--top-mode MODE Mode for top datasets plot: hardest, easiest, variable (default: hardest)
--format FORMAT Output format for plots: png, pdf, svg (default: png)
--dpi DPI DPI for output images (default: 300)
--no-show Don't display plots, just save them
--debug Enable debug logging
"""
import argparse
import json
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.patches import Patch
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("visualize_results")
plt.rcParams.update(
{
"text.usetex": True,
"font.family": "serif",
"font.serif": ["Computer Modern Roman"],
"text.latex.preamble": r"\usepackage{amsmath,amssymb,amsfonts,mathrsfs,bm}",
"axes.labelsize": 20,
"font.size": 20,
"legend.fontsize": 14,
"xtick.labelsize": 14,
"ytick.labelsize": 14,
"axes.titlesize": 22,
}
)
def load_summaries(results_dir: str) -> Dict[str, Dict[str, Any]]:
"""Load all summary.json files from subdirectories.
Args:
results_dir: Directory containing model evaluation results
Returns:
Dictionary mapping model names to their summary data
"""
summaries = {}
results_path = Path(results_dir)
if not results_path.exists():
logger.error(f"Results directory {results_dir} does not exist")
return {}
# Find all summary.json files
for model_dir in results_path.iterdir():
if not model_dir.is_dir():
continue
summary_path = model_dir / "summary.json"
if not summary_path.exists():
logger.warning(f"No summary.json found in {model_dir}")
continue
try:
# Extract model name from directory name (remove timestamp)
model_name = re.sub(r"_\d{8}_\d{6}$", "", model_dir.name)
# Replace underscores with slashes in model name for better display
model_name = model_name.replace("_", "/")
with open(summary_path, "r") as f:
summary_data = json.load(f)
# Check if summary has required fields
if "dataset_best_scores" not in summary_data:
logger.warning(f"Summary in {model_dir} is missing required fields")
continue
summaries[model_name] = summary_data
logger.info(f"Loaded summary for {model_name}")
except Exception as e:
logger.error(f"Error loading summary from {model_dir}: {str(e)}")
if not summaries:
logger.error("No valid summary files found")
return summaries
def get_dataset_categories(results_dir: str, summaries: Dict[str, Dict[str, Any]]) -> Dict[str, List[str]]:
"""Group datasets by their categories based on directory structure.
Args:
results_dir: Directory containing model evaluation results
summaries: Dictionary of model summaries
Returns:
Dictionary mapping category names to lists of dataset names
"""
categories = {}
results_path = Path(results_dir)
# Get all dataset names from the first summary
if not summaries:
return {}
first_summary = next(iter(summaries.values()))
all_datasets = set(first_summary["dataset_best_scores"].keys())
# Find categories by looking at directory structure
for model_dir in results_path.iterdir():
if not model_dir.is_dir():
continue
# Look for category directories
for category_dir in model_dir.iterdir():
if not category_dir.is_dir():
continue
category_name = category_dir.name
if category_name not in categories:
categories[category_name] = []
# Find all dataset JSON files in this category
for dataset_file in category_dir.glob("*.json"):
dataset_name = dataset_file.stem
if dataset_name in all_datasets and dataset_name not in categories[category_name]:
categories[category_name].append(dataset_name)
# Check if we found categories for all datasets
categorized_datasets = set()
for datasets in categories.values():
categorized_datasets.update(datasets)
uncategorized = all_datasets - categorized_datasets
if uncategorized:
logger.warning(f"Found {len(uncategorized)} datasets without categories")
categories["uncategorized"] = list(uncategorized)
return categories
def create_category_radar(summaries: Dict[str, Dict[str, Any]], categories: Dict[str, List[str]]) -> Figure:
"""Create a radar chart showing performance by category.
Args:
summaries: Dictionary of model summaries
categories: Dictionary mapping categories to dataset lists
Returns:
Matplotlib figure
"""
# Calculate average score per category for each model
category_scores = {}
for model_name, summary in summaries.items():
category_scores[model_name] = {}
for category, datasets in categories.items():
scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
if scores: # Avoid division by zero
category_scores[model_name][category] = np.mean(scores)
else:
category_scores[model_name][category] = 0
# Create radar chart
categories_list = sorted(categories.keys())
angles = np.linspace(0, 2 * np.pi, len(categories_list), endpoint=False).tolist()
angles += angles[:1] # Close the loop
fig, ax = plt.subplots(figsize=(12, 10), subplot_kw=dict(polar=True))
# Use a color cycle for different models
colors = plt.cm.tab10.colors
for i, (model_name, scores) in enumerate(category_scores.items()):
color = colors[i % len(colors)]
values = [scores[cat] for cat in categories_list]
values += values[:1] # Close the loop
ax.plot(angles, values, linewidth=2, label=model_name, color=color)
ax.fill(angles, values, alpha=0.1, color=color)
# Set category labels
ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories_list)
# Add radial grid lines at 0.2, 0.4, 0.6, 0.8
ax.set_rticks([0.2, 0.4, 0.6, 0.8])
ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8"])
ax.set_rlabel_position(0) # Move radial labels away from plotted line
# Add legend and title
plt.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))
plt.title("Model Performance by Category", size=15)
return fig
def create_overall_performance_bar(summaries: Dict[str, Dict[str, Any]]) -> Figure:
"""Create a bar chart of overall model performance.
Args:
summaries: Dictionary of model summaries
Returns:
Matplotlib figure
"""
# Calculate overall average score for each model
overall_scores = {}
for model_name, summary in summaries.items():
scores = list(summary["dataset_best_scores"].values())
overall_scores[model_name] = np.mean(scores)
# Sort models by performance
sorted_models = sorted(overall_scores.items(), key=lambda x: x[1], reverse=True)
# Create bar chart
fig, ax = plt.subplots(figsize=(12, 6))
models = [m[0] for m in sorted_models]
scores = [m[1] for m in sorted_models]
# Use a color gradient based on performance
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(models)))
bars = ax.bar(models, scores, color=colors)
# Add value labels on top of bars
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2%}", ha="center", va="bottom")
ax.set_ylabel("Average Score")
ax.set_ylim(0, max(scores) * 1.1) # Add some space for labels
plt.xticks(rotation=45, ha="right")
plt.title("Overall Model Performance", size=15)
plt.tight_layout()
return fig
def create_top_datasets_comparison(summaries: Dict[str, Dict[str, Any]], n: int = 15, mode: str = "hardest") -> Figure:
"""Create a bar chart comparing performance on top N datasets.
Args:
summaries: Dictionary of model summaries
n: Number of datasets to show
mode: Selection mode - 'hardest', 'easiest', or 'variable'
Returns:
Matplotlib figure
"""
if not summaries:
logger.error("No summaries provided")
return plt.figure()
# Calculate average score across all models for each dataset
dataset_avg_scores = {}
for dataset in next(iter(summaries.values()))["dataset_best_scores"].keys():
scores = [summary["dataset_best_scores"].get(dataset, 0) for summary in summaries.values()]
dataset_avg_scores[dataset] = np.mean(scores)
# Select top N datasets based on mode
if mode == "hardest":
# Select datasets with lowest average scores
selected_datasets = sorted(dataset_avg_scores.items(), key=lambda x: x[1])[:n]
elif mode == "easiest":
# Select datasets with highest average scores
selected_datasets = sorted(dataset_avg_scores.items(), key=lambda x: x[1], reverse=True)[:n]
else: # 'variable'
# Select datasets with highest variance in scores
dataset_variances = {}
for dataset in next(iter(summaries.values()))["dataset_best_scores"].keys():
scores = [summary["dataset_best_scores"].get(dataset, 0) for summary in summaries.values()]
dataset_variances[dataset] = np.var(scores)
selected_datasets = sorted(dataset_variances.items(), key=lambda x: x[1], reverse=True)[:n]
selected_datasets = [(dataset, dataset_avg_scores[dataset]) for dataset, _ in selected_datasets]
# Create horizontal bar chart
fig, ax = plt.subplots(figsize=(12, n * 0.5))
datasets = [d[0] for d in selected_datasets]
x = np.arange(len(datasets))
width = 0.8 / len(summaries)
# Use a color cycle for different models
colors = plt.cm.tab10.colors
for i, (model_name, summary) in enumerate(summaries.items()):
scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset, _ in selected_datasets]
ax.barh(
x + i * width - 0.4 + width / 2, scores, width, label=model_name, color=colors[i % len(colors)], alpha=0.8
)
ax.set_yticks(x)
ax.set_yticklabels(datasets)
ax.set_xlabel("Score")
ax.set_xlim(0, 1)
# Add legend and title
plt.legend(loc="upper right")
title = f'Model Performance on {n} {"Hardest" if mode=="hardest" else "Easiest" if mode=="easiest" else "Most Variable"} Datasets'
plt.title(title, size=15)
plt.tight_layout()
return fig
def create_performance_distribution_violin(summaries: Dict[str, Dict[str, Any]]) -> Figure:
"""Create a violin plot showing score distribution for each model.
Args:
summaries: Dictionary of model summaries
Returns:
Matplotlib figure
"""
# Prepare data for violin plot
data = []
labels = []
for model_name, summary in summaries.items():
scores = list(summary["dataset_best_scores"].values())
data.append(scores)
labels.append(model_name)
# Create violin plot
fig, ax = plt.subplots(figsize=(12, 6))
# Use a color cycle
colors = plt.cm.tab10.colors
parts = ax.violinplot(data, showmeans=True, showmedians=True)
# Customize violin plot
for i, pc in enumerate(parts["bodies"]):
pc.set_facecolor(colors[i % len(colors)])
pc.set_alpha(0.7)
# Add labels
ax.set_xticks(np.arange(1, len(labels) + 1))
ax.set_xticklabels(labels, rotation=45, ha="right")
ax.set_ylabel("Score Distribution")
ax.set_ylim(0, 1)
# Add grid for better readability
ax.yaxis.grid(True)
# Add mean and median to legend
legend_elements = [
Patch(facecolor="black", edgecolor="black", label="Mean", alpha=0.3),
Patch(facecolor="white", edgecolor="black", label="Median"),
]
ax.legend(handles=legend_elements, loc="upper right")
plt.title("Distribution of Scores Across All Datasets", size=15)
plt.tight_layout()
return fig
def create_performance_heatmap(
summaries: Dict[str, Dict[str, Any]],
categories: Dict[str, List[str]],
) -> Figure:
"""
Heat-map of model performance (0100 %) across individual datasets.
Rows : models (sorted by overall mean score, high→low)
Cols : datasets grouped by `categories`
Cell : 100 × raw score (value shown inside each cell)
"""
if not summaries:
logger.error("No summaries provided")
return plt.figure()
# ---- gather dataset names in category order
all_datasets: List[str] = []
for cat, ds in sorted(categories.items()):
all_datasets.extend(sorted(ds))
# ---- sort models by overall performance
overall = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
models = [m for m, _ in sorted(overall.items(), key=lambda x: x[1], reverse=True)]
# ---- build score matrix (0100)
score_matrix = np.zeros((len(models), len(all_datasets)))
for i, model in enumerate(models):
for j, ds in enumerate(all_datasets):
score_matrix[i, j] = 100 * summaries[model]["dataset_best_scores"].get(ds, 0.0)
# ---- plot
fig, ax = plt.subplots(figsize=(max(20, len(all_datasets) * 0.25), max(8, len(models) * 0.5)))
im = ax.imshow(score_matrix, cmap="YlOrRd", aspect="auto", vmin=0, vmax=100)
# colour-bar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Score (\%)", rotation=-90, va="bottom")
# ticks & labels
ax.set_xticks(np.arange(len(all_datasets)))
ax.set_xticklabels(all_datasets, rotation=270, fontsize=8)
ax.set_yticks(np.arange(len(models)))
ax.set_yticklabels(models)
# category separators & titles
current = 0
label_offset = -0.25
for cat, ds in sorted(categories.items()):
if not ds:
continue
nxt = current + len(ds)
if nxt < len(all_datasets):
ax.axvline(nxt - 0.5, color="white", linewidth=2)
mid = current + len(ds) / 2 - 0.5
ax.text(
mid,
label_offset,
cat,
ha="center",
va="top",
fontsize=10,
bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
)
current = nxt
# grid (mirrors comparison-plot style)
ax.set_xticks(np.arange(-0.5, len(all_datasets), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(models), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# ---- annotate every cell with its value
for i in range(len(models)):
for j in range(len(all_datasets)):
val = score_matrix[i, j]
ax.text(
j,
i,
f"{val:.1f}",
ha="center",
va="center",
fontsize=7,
rotation=-90, # 90° clockwise
rotation_mode="anchor", # keep anchor point fixed
color="white" if val >= 50 else "black",
)
plt.tight_layout()
return fig
def create_dashboard(summaries: Dict[str, Dict[str, Any]], categories: Dict[str, List[str]]) -> Figure:
"""Create a comprehensive dashboard with multiple visualizations.
Args:
summaries: Dictionary of model summaries
categories: Dictionary mapping categories to dataset lists
Returns:
Matplotlib figure
"""
if not summaries:
logger.error("No summaries provided")
return plt.figure()
fig = plt.figure(figsize=(20, 15))
# 1. Overall performance comparison
ax1 = plt.subplot2grid((2, 2), (0, 0))
models = []
scores = []
for model_name, summary in summaries.items():
models.append(model_name)
scores.append(np.mean(list(summary["dataset_best_scores"].values())))
# Sort by performance
sorted_indices = np.argsort(scores)[::-1]
models = [models[i] for i in sorted_indices]
scores = [scores[i] for i in sorted_indices]
# Use a color gradient based on performance
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(models)))
bars = ax1.bar(models, scores, color=colors)
for bar in bars:
height = bar.get_height()
ax1.text(
bar.get_x() + bar.get_width() / 2.0, height + 0.01, f"{height:.2%}", ha="center", va="bottom", fontsize=8
)
ax1.set_ylabel("Average Score")
ax1.set_ylim(0, max(scores) * 1.1)
plt.setp(ax1.get_xticklabels(), rotation=45, ha="right", fontsize=8)
ax1.set_title("Overall Model Performance", size=12)
# 2. Top 10 hardest datasets comparison
ax2 = plt.subplot2grid((2, 2), (0, 1))
# Calculate average score across all models for each dataset
dataset_avg_scores = {}
for dataset in next(iter(summaries.values()))["dataset_best_scores"].keys():
scores = [summary["dataset_best_scores"].get(dataset, 0) for summary in summaries.values()]
dataset_avg_scores[dataset] = np.mean(scores)
# Select 10 hardest datasets
hardest_datasets = sorted(dataset_avg_scores.items(), key=lambda x: x[1])[:10]
datasets = [d[0] for d in hardest_datasets]
x = np.arange(len(datasets))
width = 0.8 / len(summaries)
# Use a color cycle for different models
colors = plt.cm.tab10.colors
for i, (model_name, summary) in enumerate(summaries.items()):
scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset, _ in hardest_datasets]
ax2.barh(
x + i * width - 0.4 + width / 2, scores, width, label=model_name, color=colors[i % len(colors)], alpha=0.8
)
ax2.set_yticks(x)
ax2.set_yticklabels(datasets, fontsize=8)
ax2.set_xlabel("Score")
ax2.set_xlim(0, 1)
ax2.set_title("Performance on 10 Hardest Datasets", size=12)
ax2.legend(fontsize=8)
# 3. Category radar chart
ax3 = plt.subplot2grid((2, 2), (1, 0), polar=True)
# Calculate average score per category for each model
category_scores = {}
for model_name, summary in summaries.items():
category_scores[model_name] = {}
for category, datasets in categories.items():
scores = [summary["dataset_best_scores"].get(dataset, 0) for dataset in datasets]
if scores: # Avoid division by zero
category_scores[model_name][category] = np.mean(scores)
else:
category_scores[model_name][category] = 0
# Create radar chart
categories_list = sorted(categories.keys())
angles = np.linspace(0, 2 * np.pi, len(categories_list), endpoint=False).tolist()
angles += angles[:1] # Close the loop
for i, (model_name, scores) in enumerate(category_scores.items()):
color = colors[i % len(colors)]
values = [scores.get(cat, 0) for cat in categories_list]
values += values[:1] # Close the loop
ax3.plot(angles, values, linewidth=2, label=model_name, color=color)
ax3.fill(angles, values, alpha=0.1, color=color)
ax3.set_xticks(angles[:-1])
ax3.set_xticklabels(categories_list, fontsize=8)
ax3.set_title("Performance by Category", size=12)
# 4. Performance distribution violin plot
ax4 = plt.subplot2grid((2, 2), (1, 1))
data = []
labels = []
for model_name, summary in summaries.items():
scores = list(summary["dataset_best_scores"].values())
data.append(scores)
labels.append(model_name)
parts = ax4.violinplot(data, showmeans=True, showmedians=True)
for i, pc in enumerate(parts["bodies"]):
pc.set_facecolor(colors[i % len(colors)])
pc.set_alpha(0.7)
ax4.set_xticks(np.arange(1, len(labels) + 1))
ax4.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
ax4.set_ylabel("Score Distribution")
ax4.set_ylim(0, 1)
ax4.yaxis.grid(True)
ax4.set_title("Distribution of Scores", size=12)
plt.tight_layout()
plt.suptitle("Model Evaluation Dashboard", size=16, y=0.98)
plt.subplots_adjust(top=0.9)
return fig
def create_comparison_plot(
summaries: Dict[str, Dict[str, Any]],
other_summaries: Dict[str, Dict[str, Any]],
categories: Optional[Dict[str, List[str]]] = None,
compare_model_ids: Optional[List[str]] = None,
) -> Figure:
"""
Build a heat-map of per-category score differences (scaled to 100 … 100).
Rows : category names (`categories`)
Cols : model IDs present in both `summaries` and `other_summaries`
Value : 100 × (mean(score in summaries) mean(score in other_summaries))
A numeric annotation (rounded to 2 dp) is rendered in every cell.
"""
if not summaries or not other_summaries:
logger.error("No summaries provided for comparison")
return plt.figure()
if categories is None:
all_ds = next(iter(summaries.values()))["dataset_best_scores"].keys()
categories = {"all": list(all_ds)}
# models present in both result sets
common_models = [m for m in summaries if m in other_summaries]
if not common_models:
logger.error("No overlapping model IDs between the two result sets.")
return plt.figure()
# sort models by overall performance
overall_scores = {m: np.mean(list(s["dataset_best_scores"].values())) for m, s in summaries.items()}
models = [m for m, _ in sorted(overall_scores.items(), key=lambda x: x[1], reverse=True) if m in common_models]
if compare_model_ids:
models = [m for m in models if m in compare_model_ids]
category_list = sorted(categories.keys())
# ---------- note the transposed shape (categories × models)
diff_matrix = np.zeros((len(category_list), len(models)))
# compute 100 × Δ
for i, cat in enumerate(category_list):
ds = categories[cat]
for j, model in enumerate(models):
cur_scores = summaries[model]["dataset_best_scores"]
base_scores = other_summaries[model]["dataset_best_scores"]
cur_mean = np.mean([cur_scores.get(d, 0.0) for d in ds]) if ds else 0.0
base_mean = np.mean([base_scores.get(d, 0.0) for d in ds]) if ds else 0.0
diff_matrix[i, j] = 100 * (cur_mean - base_mean)
# ---------------------------------------------------------------- plot
fig, ax = plt.subplots(figsize=(max(8, len(models) * 1.2), max(6, len(category_list) * 0.58)))
im = ax.imshow(diff_matrix, cmap="coolwarm", aspect="auto", vmin=-100, vmax=100)
# colour-bar
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("$\Delta$ score (\%)", rotation=-90, va="bottom", fontweight="bold")
# ticks / labels
ax.set_xticks(np.arange(len(models)), labels=models, rotation=45, ha="right")
ax.set_yticks(np.arange(len(category_list)), labels=category_list)
# grid for readability
ax.set_xticks(np.arange(-0.5, len(models), 1), minor=True)
ax.set_yticks(np.arange(-0.5, len(category_list), 1), minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=0.5)
# annotate each cell
for i in range(len(category_list)):
for j in range(len(models)):
value = diff_matrix[i, j]
ax.text(
j,
i,
f"{value:.2f}",
ha="center",
va="center",
color="black" if abs(value) < 50 else "white",
fontsize=12,
)
# ax.set_title("Per-Category Performance $\Delta$ (hard easy)", fontweight="bold")
plt.tight_layout()
return fig
def save_figure(fig: Figure, output_dir: str, name: str, fmt: str = "png", dpi: int = 300) -> str:
"""Save a figure to a file.
Args:
fig: Matplotlib figure to save
output_dir: Directory to save the figure
name: Base name for the figure file
fmt: File format (png, pdf, svg)
dpi: DPI for raster formats
Returns:
Path to the saved file
"""
# Create filename
filename = f"{name.replace('/', '-')}.{fmt}"
filepath = output_dir / filename
# Save figure
fig.savefig(filepath, dpi=dpi, bbox_inches="tight")
logger.info(f"Saved {filepath}")
return filepath
def main():
"""Main function."""
parser = argparse.ArgumentParser(description="Generate visualizations from evaluation results")
parser.add_argument("--results-dir", required=True, help="Directory containing evaluation results")
parser.add_argument("--output-dir", default="visualizations", help="Directory to save visualizations")
parser.add_argument("--plots", default="all", help="Comma-separated list of plots to generate")
parser.add_argument("--top-n", type=int, default=15, help="Number of datasets to show in top datasets plot")
parser.add_argument(
"--top-mode", default="hardest", choices=["hardest", "easiest", "variable"], help="Mode for top datasets plot"
)
parser.add_argument("--compare-results-dir", help="Directory to compare results with", default=None)
parser.add_argument("--compare-model-ids", help="Comma-separated list of model IDs to compare", default=None)
parser.add_argument("--format", default="png", choices=["png", "pdf", "svg"], help="Output format for plots")
parser.add_argument("--dpi", type=int, default=300, help="DPI for output images")
parser.add_argument("--no-show", action="store_true", help="Don't display plots, just save them")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
args = parser.parse_args()
# Configure logging
if args.debug:
logger.setLevel(logging.DEBUG)
# Load summaries
logger.info(f"Loading summaries from {args.results_dir}")
summaries = load_summaries(args.results_dir)
args.output_dir = Path(args.output_dir)
if not args.output_dir.exists():
logger.info(f"Creating output directory {args.output_dir}")
args.output_dir.mkdir(parents=True, exist_ok=True)
if not summaries:
logger.error("No valid summaries found. Exiting.")
return 1
logger.info(f"Found {len(summaries)} model summaries")
# Get dataset categories
categories = get_dataset_categories(args.results_dir, summaries)
logger.info(f"Found {len(categories)} dataset categories")
# Determine which plots to generate
if args.plots.lower() == "all":
plots_to_generate = ["radar", "bar", "violin", "heatmap", "dashboard", "top_datasets", "compare"]
else:
plots_to_generate = [p.strip().lower() for p in args.plots.split(",")]
logger.info(f"Generating plots: {', '.join(plots_to_generate)}")
# Generate and save plots
for plot_type in plots_to_generate:
try:
if plot_type == "radar":
fig = create_category_radar(summaries, categories)
save_figure(fig, args.output_dir, "category_radar", args.format, args.dpi)
elif plot_type == "bar":
fig = create_overall_performance_bar(summaries)
save_figure(fig, args.output_dir, "overall_performance", args.format, args.dpi)
elif plot_type == "violin":
fig = create_performance_distribution_violin(summaries)
save_figure(fig, args.output_dir, "score_distribution", args.format, args.dpi)
elif plot_type == "heatmap":
fig = create_performance_heatmap(summaries, categories)
save_figure(fig, args.output_dir, "performance_heatmap", args.format, args.dpi)
elif plot_type == "dashboard":
fig = create_dashboard(summaries, categories)
save_figure(fig, args.output_dir, "evaluation_dashboard", args.format, args.dpi)
elif plot_type == "top_datasets":
fig = create_top_datasets_comparison(summaries, args.top_n, args.top_mode)
save_figure(fig, args.output_dir, f"top_{args.top_n}_{args.top_mode}_datasets", args.format, args.dpi)
elif plot_type == "compare":
assert args.compare_results_dir, "Comparison directory is required for compare plot"
other_summaries = load_summaries(args.compare_results_dir)
if not other_summaries:
logger.error("No valid summaries found in comparison directory. Exiting.")
return 1
compare_model_ids = args.compare_model_ids.split(",") if args.compare_model_ids else None
fig = create_comparison_plot(summaries, other_summaries, categories, compare_model_ids)
save_figure(fig, args.output_dir, "model_category_delta_heatmap", args.format, args.dpi)
else:
logger.warning(f"Unknown plot type: {plot_type}")
continue
# Show plot if requested
if not args.no_show:
plt.show()
else:
plt.close(fig)
except Exception as e:
logger.error(f"Error generating {plot_type} plot: {str(e)}")
if args.debug:
import traceback
traceback.print_exc()
logger.info(f"All visualizations saved to {args.output_dir}")
return 0
if __name__ == "__main__":
exit_code = main()
import sys
sys.exit(exit_code)