Mirror of https://github.com/zilliztech/claude-context.git, synced 2025-10-06 01:10:02 +03:00
Add Evaluation (#164)
* build evaluation framework
* add
* Update evaluation framework

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Signed-off-by: Cheney <chen.zhang@zilliz.com>
.gitignore (vendored, 5 lines added)
@@ -61,3 +61,8 @@ __pycache__/
CLAUDE.md
.cursor/*
evaluation/repos
repos
evaluation/retrieval_results*
README.md (10 lines added)
@@ -574,6 +574,16 @@ Check the `/examples` directory for complete usage examples:

---

## 📊 Evaluation

Our controlled evaluation shows that Claude Context MCP achieves roughly 40% token reduction while delivering equivalent retrieval quality, which translates into significant cost and time savings in production. Put differently, under a fixed token context budget, Claude Context leaves more room for useful context and therefore yields better retrieval and answer results.



For detailed evaluation methodology and results, see the [evaluation directory](evaluation/).

---

## ❓ FAQ

**Common Questions:**
BIN  assets/mcp_efficiency_analysis_chart.png (new binary file, 170 KiB; not shown)
evaluation/.python-version (new file, 1 line)
3.10
evaluation/README.md (new file, 106 lines)

# Claude Context MCP Evaluation

This directory contains the evaluation framework and experimental results for comparing the efficiency of code retrieval with Claude Context MCP versus a traditional grep-only approach.

## Overview

We conducted a controlled experiment to measure the impact of adding the Claude Context MCP tool to a baseline coding agent. The evaluation demonstrates significant improvements in token efficiency while maintaining comparable retrieval quality.

## Experimental Design

We designed a controlled experiment comparing two coding agents performing identical retrieval tasks. The baseline agent uses simple tools: read, grep, and edit. The enhanced agent adds the Claude Context MCP tool on top of the same foundation. Both agents work on the same dataset with the same model to ensure a fair comparison. The agents are implemented with the [LangGraph MCP and ReAct framework](https://langchain-ai.github.io/langgraph/agents/mcp/#use-mcp-tools).
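
The agent wiring follows the pattern used in [`client.py`](./client.py) and [`retrieval/custom.py`](./retrieval/custom.py): MCP servers are registered with `MultiServerMCPClient`, their tools are exposed as LangChain tools, and a ReAct agent is built around them. A minimal sketch (the model name and server path here are illustrative):

```python
import asyncio

from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

# Register the baseline grep MCP server; retrieval/custom.py adds the
# filesystem, edit, and (optionally) claude-context servers the same way.
client = MultiServerMCPClient(
    {
        "grep": {
            "command": "python",
            "args": ["servers/grep_server.py"],
            "transport": "stdio",
        },
    }
)

async def ask(query: str):
    async with client.session("grep") as session:
        tools = await load_mcp_tools(session)  # expose MCP tools as LangChain tools
        agent = create_react_agent(ChatOpenAI(model="gpt-4o-mini"), tools)
        return await agent.ainvoke({"messages": [{"role": "user", "content": query}]})

if __name__ == "__main__":
    print(asyncio.run(ask("Where is token usage accounted for in this repo?")))
```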

We selected 30 instances from Princeton NLP's [SWE-bench_Verified](https://openai.com/index/introducing-swe-bench-verified/) dataset, filtering for problems rated at 15-60 minutes of difficulty whose gold patches modify exactly 2 files. This subset represents typical coding tasks and enables quick validation. The dataset generation is implemented in [`generate_subset_json.py`](./generate_subset_json.py).
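
The subset filter itself is small; a condensed sketch of what [`generate_subset_json.py`](./generate_subset_json.py) does (its `parse_patch_files` helper counts the files touched by the gold patch):

```python
from datasets import load_dataset

from generate_subset_json import parse_patch_files  # helper defined in this directory

instances = list(load_dataset("princeton-nlp/SWE-bench_Verified")["test"])
subset = [
    inst
    for inst in instances
    if inst.get("difficulty") == "15 min - 1 hour"          # 15-60 minute problems
    and len(parse_patch_files(inst.get("patch", ""))) == 2  # exactly 2 modified files
]
```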

We chose [GPT-4o-mini](https://platform.openai.com/docs/models/gpt-4o-mini) as the default model for cost-effectiveness.

We ran each method 3 times independently, giving 6 runs in total for statistical reliability. We measured token usage, tool calls, and retrieval precision, recall, and F1-score across all runs. The main entry point for running evaluations is [`run_evaluation.py`](./run_evaluation.py).
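
Retrieval quality is scored per instance by comparing the files the agent edits ("hits") against the files changed by the gold patch ("oracles"), as implemented in [`analyze_and_plot_mcp_efficiency.py`](./analyze_and_plot_mcp_efficiency.py):

```math
P = \frac{|\text{hits} \cap \text{oracles}|}{|\text{hits}|}, \qquad
R = \frac{|\text{hits} \cap \text{oracles}|}{|\text{oracles}|}, \qquad
F_1 = \frac{2PR}{P + R}
```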

## Key Results

### Performance Summary

| Metric | Baseline (Grep Only) | With Claude Context MCP | Improvement |
|--------|----------------------|-------------------------|-------------|
| **Average F1-Score** | 0.40 | 0.40 | Comparable |
| **Average Token Usage** | 73,373 | 44,449 | **-39.4%** |
| **Average Tool Calls** | 8.3 | 5.3 | **-36.3%** |

### Key Findings

**Dramatic Efficiency Gains**:

With Claude Context MCP, we achieved:
- **39.4% reduction** in token consumption (28,924 tokens saved per instance)
- **36.3% reduction** in tool calls (3.0 fewer calls per instance)

## Conclusion

The results demonstrate that Claude Context MCP provides:

### Immediate Benefits
- **Cost Efficiency**: ~40% less token usage directly reduces operational costs
- **Speed Improvement**: Fewer tool calls and tokens mean faster code localization and task completion
- **Better Quality**: Under a limited token context length, the saved budget leaves more room for relevant context, yielding better retrieval and answer results

### Strategic Advantages
- **Better Resource Utilization**: Under fixed token budgets, Claude Context MCP enables handling more tasks
- **Wider Usage Scenarios**: Lower per-task costs make broader usage scenarios viable
- **Improved User Experience**: Faster responses with maintained accuracy

## Running the Evaluation

To reproduce these results:

1. **Install Dependencies**:

   For the Python environment, use `uv` to install the locked dependencies:
   ```bash
   cd evaluation && uv sync
   source .venv/bin/activate
   ```
   For the Node environment, make sure your Node.js version satisfies `>= 20.0.0 and < 24.0.0`.

   Our evaluation results were produced with `claude-context-mcp@0.1.0`; you can change the `claude-context` MCP server settings in `retrieval/custom.py` to use the latest release or a development build.

2. **Set Environment Variables**:
   ```bash
   export OPENAI_API_KEY=your_openai_api_key
   export MILVUS_ADDRESS=your_milvus_address
   ```
   For more configuration details, refer to the `claude-context` MCP server settings in `retrieval/custom.py` (a condensed sketch of that configuration follows this list).

   ```bash
   export GITHUB_TOKEN=your_github_token
   ```
   You also need to prepare a `GITHUB_TOKEN` for automatically cloning the repositories; refer to the [SWE-bench documentation](https://www.swebench.com/SWE-bench/guides/create_rag_datasets/#example-usage) for details.

3. **Generate Dataset**:
   ```bash
   python generate_subset_json.py
   ```

4. **Run Baseline Evaluation**:
   ```bash
   python run_evaluation.py --retrieval_types grep --output_dir retrieval_results_grep
   ```

5. **Run Enhanced Evaluation**:
   ```bash
   python run_evaluation.py --retrieval_types cc,grep --output_dir retrieval_results_both
   ```

6. **Analyze Results**:
   ```bash
   python analyze_and_plot_mcp_efficiency.py
   ```
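
For reference, the `claude-context` server entry configured in `retrieval/custom.py` (referenced in step 2 above) looks like this condensed excerpt; pin or bump the package version and adjust the environment variables here:

```python
import os

# Condensed from retrieval/custom.py: the claude-context MCP server entry.
servers = {
    "claude-context": {
        "command": "npx",
        "args": ["-y", "@zilliz/claude-context-mcp@0.1.0"],
        "env": {
            "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
            "MILVUS_ADDRESS": os.getenv("MILVUS_ADDRESS"),
            "EMBEDDING_BATCH_SIZE": os.getenv("EMBEDDING_BATCH_SIZE", "100"),
        },
        "transport": "stdio",
    }
}
```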

The evaluation framework is designed to be reproducible and can easily be extended to additional configurations or datasets. Because LLM outputs are not deterministic, exact numbers may vary between runs and cannot be guaranteed to be identical; the core conclusions, however, remain consistent and robust across runs.

## Results Visualization

![MCP Efficiency Analysis](../assets/mcp_efficiency_analysis_chart.png)

*The chart above shows the efficiency improvements achieved by Claude Context MCP while maintaining equivalent retrieval quality: token usage and tool calls are significantly reduced with no loss in F1-score.*
evaluation/analyze_and_plot_mcp_efficiency.py (new file, 372 lines)
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze retrieval results and create MCP efficiency chart using real data.
|
||||
This script loads data from the actual result directories and generates seaborn charts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def normalize_file_path(file_path: str) -> str:
|
||||
"""Normalize file paths."""
|
||||
if file_path.startswith("/"):
|
||||
file_path = file_path[1:]
|
||||
return file_path
|
||||
|
||||
|
||||
def calculate_metrics(hits: List[str], oracles: List[str]) -> Dict[str, float]:
|
||||
"""Calculate precision, recall, and F1-score."""
|
||||
if not hits and not oracles:
|
||||
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
||||
|
||||
# Normalize file paths
|
||||
hits_set = set(normalize_file_path(f) for f in hits)
|
||||
oracles_set = set(normalize_file_path(f) for f in oracles)
|
||||
|
||||
# Calculate intersection
|
||||
intersection = hits_set.intersection(oracles_set)
|
||||
|
||||
# Calculate metrics
|
||||
precision = len(intersection) / len(hits_set) if hits_set else 0.0
|
||||
recall = len(intersection) / len(oracles_set) if oracles_set else 0.0
|
||||
f1 = (
|
||||
2 * (precision * recall) / (precision + recall)
|
||||
if (precision + recall) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1": f1,
|
||||
"num_hits": len(hits_set),
|
||||
"num_oracles": len(oracles_set),
|
||||
"num_correct": len(intersection),
|
||||
}
|
||||
|
||||
|
||||
def load_method_results(method_dirs: List[str], method_name: str) -> Dict:
|
||||
"""Load and aggregate results from multiple runs of the same method."""
|
||||
|
||||
all_f1_scores = []
|
||||
all_token_usage = []
|
||||
all_tool_calls = []
|
||||
successful_instances = set()
|
||||
|
||||
print(f"\nLoading {method_name} method data from {len(method_dirs)} runs...")
|
||||
|
||||
for run_idx, run_dir in enumerate(method_dirs):
|
||||
print(f" Processing run {run_idx + 1}: {run_dir}")
|
||||
|
||||
if not os.path.exists(run_dir):
|
||||
print(f" Warning: Directory {run_dir} does not exist")
|
||||
continue
|
||||
|
||||
run_success_count = 0
|
||||
run_f1_scores = []
|
||||
run_tokens = []
|
||||
run_tools = []
|
||||
|
||||
for item in os.listdir(run_dir):
|
||||
instance_dir = os.path.join(run_dir, item)
|
||||
result_file = os.path.join(instance_dir, "result.json")
|
||||
|
||||
if os.path.isdir(instance_dir) and os.path.exists(result_file):
|
||||
try:
|
||||
with open(result_file, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Calculate F1-score
|
||||
hits = data.get("hits", [])
|
||||
oracles = data.get("oracles", [])
|
||||
metrics = calculate_metrics(hits, oracles)
|
||||
|
||||
# Extract other metrics
|
||||
tokens = data.get("token_usage", {}).get("total_tokens", 0)
|
||||
tools = data.get("tool_stats", {}).get("total_tool_calls", 0)
|
||||
|
||||
# Store data
|
||||
run_f1_scores.append(metrics["f1"])
|
||||
run_tokens.append(tokens)
|
||||
run_tools.append(tools)
|
||||
|
||||
successful_instances.add(item)
|
||||
run_success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to load {result_file}: {e}")
|
||||
continue
|
||||
|
||||
print(f" Loaded {run_success_count} successful instances")
|
||||
|
||||
# Add this run's data to overall collection
|
||||
all_f1_scores.extend(run_f1_scores)
|
||||
all_token_usage.extend(run_tokens)
|
||||
all_tool_calls.extend(run_tools)
|
||||
|
||||
# Calculate aggregated statistics
|
||||
results = {
|
||||
"method_name": method_name,
|
||||
"total_runs": len(method_dirs),
|
||||
"successful_instances": len(successful_instances),
|
||||
"avg_f1": np.mean(all_f1_scores) if all_f1_scores else 0,
|
||||
"std_f1": np.std(all_f1_scores) if all_f1_scores else 0,
|
||||
"avg_tokens": np.mean(all_token_usage) if all_token_usage else 0,
|
||||
"std_tokens": np.std(all_token_usage) if all_token_usage else 0,
|
||||
"avg_tools": np.mean(all_tool_calls) if all_tool_calls else 0,
|
||||
"std_tools": np.std(all_tool_calls) if all_tool_calls else 0,
|
||||
}
|
||||
|
||||
print(f" Aggregated results:")
|
||||
print(f" Avg F1-Score: {results['avg_f1']:.3f} ± {results['std_f1']:.3f}")
|
||||
print(f" Avg Tokens: {results['avg_tokens']:.0f} ± {results['std_tokens']:.0f}")
|
||||
print(
|
||||
f" Avg Tool Calls: {results['avg_tools']:.1f} ± {results['std_tools']:.1f}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def create_efficiency_chart(both_results: Dict, grep_results: Dict):
|
||||
"""Create the efficiency comparison chart using Seaborn."""
|
||||
|
||||
# Set the aesthetic style
|
||||
sns.set_style("whitegrid")
|
||||
sns.set_palette("husl")
|
||||
|
||||
# Prepare data for plotting
|
||||
data = {
|
||||
"Method": [
|
||||
"With claude-context MCP",
|
||||
"Baseline",
|
||||
"With claude-context MCP",
|
||||
"Baseline",
|
||||
],
|
||||
"Metric": ["Token Usage", "Token Usage", "Tool Calls", "Tool Calls"],
|
||||
"Value": [
|
||||
both_results["avg_tokens"] / 1000, # Convert to thousands
|
||||
grep_results["avg_tokens"] / 1000,
|
||||
both_results["avg_tools"],
|
||||
grep_results["avg_tools"],
|
||||
],
|
||||
}
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Create figure with custom styling
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
|
||||
|
||||
# Custom color palette
|
||||
colors = ["#3498db", "#e74c3c"] # Modern blue and red
|
||||
|
||||
# Token Usage subplot
|
||||
token_data = df[df["Metric"] == "Token Usage"]
|
||||
sns.barplot(
|
||||
data=token_data,
|
||||
x="Method",
|
||||
y="Value",
|
||||
ax=ax1,
|
||||
palette=colors,
|
||||
alpha=0.8,
|
||||
edgecolor="white",
|
||||
linewidth=2,
|
||||
)
|
||||
|
||||
ax1.set_title("Token Usage", fontsize=18, fontweight="bold", pad=20)
|
||||
ax1.set_ylabel("Average Tokens (K)", fontsize=14, fontweight="bold")
|
||||
ax1.set_xlabel("")
|
||||
ax1.tick_params(axis="x", labelsize=12)
|
||||
ax1.tick_params(axis="y", labelsize=12)
|
||||
# Set y-axis range with some padding
|
||||
ax1.set_ylim(0, max(token_data["Value"]) * 1.15)
|
||||
|
||||
# Add value labels for token usage
|
||||
token_values = [
|
||||
both_results["avg_tokens"] / 1000,
|
||||
grep_results["avg_tokens"] / 1000,
|
||||
]
|
||||
for i, val in enumerate(token_values):
|
||||
ax1.text(
|
||||
i,
|
||||
val + 2,
|
||||
f"{val:.1f}K",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=13,
|
||||
color=colors[i],
|
||||
)
|
||||
|
||||
# Add improvement annotation for tokens
|
||||
token_reduction = (
|
||||
(grep_results["avg_tokens"] - both_results["avg_tokens"])
|
||||
/ grep_results["avg_tokens"]
|
||||
* 100
|
||||
)
|
||||
mid_height = max(token_values) * 0.8
|
||||
ax1.annotate(
|
||||
f"-{token_reduction:.1f}%",
|
||||
xy=(0.5, mid_height),
|
||||
xycoords="data",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=16,
|
||||
fontweight="bold",
|
||||
bbox=dict(
|
||||
boxstyle="round,pad=0.5",
|
||||
facecolor="#2ecc71",
|
||||
alpha=0.8,
|
||||
edgecolor="white",
|
||||
linewidth=2,
|
||||
),
|
||||
color="white",
|
||||
)
|
||||
|
||||
# Tool Calls subplot
|
||||
tool_data = df[df["Metric"] == "Tool Calls"]
|
||||
sns.barplot(
|
||||
data=tool_data,
|
||||
x="Method",
|
||||
y="Value",
|
||||
ax=ax2,
|
||||
palette=colors,
|
||||
alpha=0.8,
|
||||
edgecolor="white",
|
||||
linewidth=2,
|
||||
)
|
||||
|
||||
ax2.set_title("Tool Calls", fontsize=18, fontweight="bold", pad=20)
|
||||
ax2.set_ylabel("Average Number of Calls", fontsize=14, fontweight="bold")
|
||||
ax2.set_xlabel("")
|
||||
ax2.tick_params(axis="x", labelsize=12)
|
||||
ax2.tick_params(axis="y", labelsize=12)
|
||||
# Set y-axis range with some padding
|
||||
ax2.set_ylim(0, max(tool_data["Value"]) * 1.15)
|
||||
|
||||
# Add value labels for tool calls
|
||||
tool_values = [both_results["avg_tools"], grep_results["avg_tools"]]
|
||||
for i, val in enumerate(tool_values):
|
||||
ax2.text(
|
||||
i,
|
||||
val + 0.2,
|
||||
f"{val:.1f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=13,
|
||||
color=colors[i],
|
||||
)
|
||||
|
||||
# Add improvement annotation for tool calls
|
||||
tool_reduction = (
|
||||
(grep_results["avg_tools"] - both_results["avg_tools"])
|
||||
/ grep_results["avg_tools"]
|
||||
* 100
|
||||
)
|
||||
mid_height = max(tool_values) * 0.8
|
||||
ax2.annotate(
|
||||
f"-{tool_reduction:.1f}%",
|
||||
xy=(0.5, mid_height),
|
||||
xycoords="data",
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=16,
|
||||
fontweight="bold",
|
||||
bbox=dict(
|
||||
boxstyle="round,pad=0.5",
|
||||
facecolor="#2ecc71",
|
||||
alpha=0.8,
|
||||
edgecolor="white",
|
||||
linewidth=2,
|
||||
),
|
||||
color="white",
|
||||
)
|
||||
|
||||
# Keep x-axis labels horizontal and add grid
|
||||
for ax in [ax1, ax2]:
|
||||
ax.tick_params(axis="x", rotation=0)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Adjust layout
|
||||
plt.tight_layout()
|
||||
|
||||
# Save with high quality
|
||||
output_file = "mcp_efficiency_analysis_chart.png"
|
||||
plt.savefig(
|
||||
output_file, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
|
||||
)
|
||||
plt.show()
|
||||
|
||||
print(f"\nChart saved as: {output_file}")
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*80}")
|
||||
print(f"MCP EFFICIENCY ANALYSIS SUMMARY")
|
||||
print(f"{'='*80}")
|
||||
print(f"Method Comparison:")
|
||||
print(f" Both (MCP) vs Grep (Baseline)")
|
||||
print(
|
||||
f" Runs per method: {both_results['total_runs']} vs {grep_results['total_runs']}"
|
||||
)
|
||||
|
||||
print(f"\nF1-Score Comparison:")
|
||||
print(f" Both Method: {both_results['avg_f1']:.3f} ± {both_results['std_f1']:.3f}")
|
||||
print(f" Grep Method: {grep_results['avg_f1']:.3f} ± {grep_results['std_f1']:.3f}")
|
||||
f1_change = (
|
||||
(
|
||||
(both_results["avg_f1"] - grep_results["avg_f1"])
|
||||
/ grep_results["avg_f1"]
|
||||
* 100
|
||||
)
|
||||
if grep_results["avg_f1"] > 0
|
||||
else 0
|
||||
)
|
||||
print(f" F1-Score change: {f1_change:+.1f}%")
|
||||
|
||||
print(f"\nEfficiency Improvements:")
|
||||
print(
|
||||
f" Token usage reduction: {token_reduction:.1f}% (from {grep_results['avg_tokens']:.0f} to {both_results['avg_tokens']:.0f})"
|
||||
)
|
||||
print(
|
||||
f" Tool calls reduction: {tool_reduction:.1f}% (from {grep_results['avg_tools']:.1f} to {both_results['avg_tools']:.1f})"
|
||||
)
|
||||
print(
|
||||
f" Average token savings: {grep_results['avg_tokens'] - both_results['avg_tokens']:.0f} tokens per instance"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to analyze and plot MCP efficiency."""
|
||||
|
||||
print("MCP Efficiency Analysis - Loading Data")
|
||||
print("=" * 60)
|
||||
|
||||
# Define directories for each method
|
||||
both_dirs = [
|
||||
"retrieval_results_both",
|
||||
"retrieval_results_both2",
|
||||
"retrieval_results_both3",
|
||||
]
|
||||
|
||||
grep_dirs = [
|
||||
"retrieval_results_grep",
|
||||
"retrieval_results_grep2",
|
||||
"retrieval_results_grep3",
|
||||
]
|
||||
|
||||
# Load and analyze results
|
||||
both_results = load_method_results(both_dirs, "Both (with claude-context MCP)")
|
||||
grep_results = load_method_results(grep_dirs, "Grep (baseline)")
|
||||
|
||||
# Create the efficiency chart
|
||||
create_efficiency_chart(both_results, grep_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||

evaluation/client.py (new file, 62 lines)

import asyncio
from langgraph.prebuilt import create_react_agent
from utils.format import (
    extract_conversation_summary,
    extract_file_paths_from_edits,
    calculate_total_tokens,
)


class Evaluator:
    """Evaluator class for running LLM queries with MCP tools"""

    def __init__(self, llm_model, tools):
        """
        Initialize the Evaluator

        Args:
            llm_model: LangChain LLM model instance (required)
            tools: List of tools to use (required)
        """
        self.llm_model = llm_model
        self.tools = tools
        self.agent = create_react_agent(self.llm_model, self.tools)

        # Setup event loop for sync usage
        try:
            self.loop = asyncio.get_event_loop()
        except RuntimeError:
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

    async def async_run(self, query, codebase_path=None):
        """Run the query asynchronously through the ReAct agent"""
        response = await self.agent.ainvoke(
            {"messages": [{"role": "user", "content": query}]},
            config={"recursion_limit": 150},
        )

        # Extract data without printing
        conversation_summary, tool_stats = extract_conversation_summary(response)
        token_usage = calculate_total_tokens(response)

        if codebase_path:
            file_paths = extract_file_paths_from_edits(response, codebase_path)
        else:
            file_paths = []

        return conversation_summary, token_usage, file_paths, tool_stats

    def run(self, query: str, codebase_path=None):
        """
        Run a query synchronously

        Args:
            query (str): The query to execute
            codebase_path (str): Path to the codebase for relative path conversion

        Returns:
            tuple: (conversation_summary, token_usage, file_paths, tool_stats)
        """
        return asyncio.run(self.async_run(query, codebase_path))

evaluation/generate_subset_json.py (new file, 98 lines)

#!/usr/bin/env python3
"""
Generate swe_verified_15min1h_2files_instances.json from the subset analysis
"""

import json
import re
from datasets import load_dataset


def parse_patch_files(patch_content):
    """Parse patch content to extract the list of modified files"""
    if not patch_content:
        return []

    file_pattern = r'^diff --git a/(.*?) b/(.*?)$'
    files = []

    for line in patch_content.split('\n'):
        match = re.match(file_pattern, line)
        if match:
            file_path = match.group(1)
            files.append(file_path)

    return files


def main():
    print("Loading SWE-bench_Verified dataset...")
    dataset = load_dataset("princeton-nlp/SWE-bench_Verified")
    instances = list(dataset['test'])

    print("Filtering instances for: 15min-1hour difficulty + 2 patch files...")

    # Filter for the specific subset
    subset_instances = []

    for instance in instances:
        difficulty = instance.get('difficulty', 'Unknown')

        # Parse main patch to count files
        patch_content = instance.get('patch', '')
        patch_files = parse_patch_files(patch_content)
        oracle_count = len(patch_files)

        # Check if it matches our criteria
        if difficulty == '15 min - 1 hour' and oracle_count == 2:
            subset_instances.append(instance)

    print(f"Found {len(subset_instances)} instances matching criteria")

    # Create the JSON structure that _prepare_instances expects
    output_data = {
        "metadata": {
            "description": "SWE-bench_Verified subset: 15min-1hour difficulty with 2 patch files",
            "source_dataset": "princeton-nlp/SWE-bench_Verified",
            "extraction_date": "2024",
            "filter_criteria": {
                "difficulty": "15 min - 1 hour",
                "patch_files_count": 2
            },
            "total_instances": len(subset_instances),
            "statistics": {
                "total_instances_in_original": 500,
                "subset_count": len(subset_instances),
                "percentage_of_original": round((len(subset_instances) / 500) * 100, 1)
            }
        },
        "instances": subset_instances
    }

    # Save to JSON file
    output_file = "swe_verified_15min1h_2files_instances.json"
    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"Generated {output_file} with {len(subset_instances)} instances")

    # Verify the structure
    print("\nVerifying JSON structure...")
    with open(output_file, 'r') as f:
        loaded_data = json.load(f)

    print(f"✓ Contains 'instances' key: {'instances' in loaded_data}")
    print(f"✓ Contains 'metadata' key: {'metadata' in loaded_data}")
    print(f"✓ Number of instances: {len(loaded_data['instances'])}")
    print(f"✓ First instance has required fields:")

    if loaded_data['instances']:
        first_instance = loaded_data['instances'][0]
        required_fields = ['instance_id', 'repo', 'base_commit', 'problem_statement']
        for field in required_fields:
            has_field = field in first_instance
            print(f"  - {field}: {'✓' if has_field else '✗'}")

    print(f"\nFile successfully generated: {output_file}")
    print("This file can be used with BaseRetrieval._prepare_instances()")


if __name__ == "__main__":
    main()

evaluation/pyproject.toml (new file, 24 lines)

[project]
name = "evaluation"
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "langchain-mcp-adapters>=0.1.9",
    "langgraph>=0.6.4",
    "mcp>=1.12.4",
    "langchain>=0.1.0",
    "langchain-core>=0.1.0",
    "langchain-anthropic>=0.1.0",
    "langchain-openai>=0.3.29",
    "langchain-ollama>=0.3.6",
    "datasets>=4.0.0",
    "gitpython>=3.1.45",
    "matplotlib>=3.10.5",
    "seaborn>=0.13.2",
    "pandas>=2.3.1",
    "numpy>=2.2.6",
    "plotly>=6.3.0",
    "kaleido>=1.0.0",
]

evaluation/retrieval/__init__.py (new empty file)

evaluation/retrieval/base.py (new file, 210 lines)
import json
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from tqdm.auto import tqdm
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
from datasets import load_from_disk, load_dataset
|
||||
|
||||
from utils.file_management import ContextManager, clone_repo, get_remaining_instances
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRetrieval:
|
||||
def __init__(
|
||||
self, *, dataset_name_or_path, splits, output_dir, max_instances=None, **kwargs
|
||||
):
|
||||
self.dataset_name_or_path = dataset_name_or_path
|
||||
self.splits = splits
|
||||
self.output_dir = output_dir
|
||||
self.max_instances = max_instances
|
||||
self.instances = self._prepare_instances()
|
||||
self.prompt = """The codebase is at {repo_path}.
|
||||
|
||||
Issue:
|
||||
<issue>
|
||||
{issue}
|
||||
</issue>
|
||||
|
||||
Your task is to identify and edit the files that need to be modified to resolve the issue.
|
||||
Focus on making the necessary changes to completely address the problem.
|
||||
Use the available tools step by step to accomplish this goal. The primary objective is to edit the existing code files. No validation or testing is required.
|
||||
"""
|
||||
|
||||
def _prepare_instances(self) -> List[Dict]:
|
||||
if Path(self.dataset_name_or_path).exists():
|
||||
# Check if it's a JSON file
|
||||
if self.dataset_name_or_path.endswith(".json"):
|
||||
with open(self.dataset_name_or_path, "r") as f:
|
||||
data = json.load(f)
|
||||
# If it's our custom JSON format with instances data
|
||||
if "instances" in data:
|
||||
logger.info(
|
||||
f"Loaded {len(data['instances'])} instances from JSON file"
|
||||
)
|
||||
if "metadata" in data and "statistics" in data["metadata"]:
|
||||
logger.info(f"Statistics: {data['metadata']['statistics']}")
|
||||
# Create a simple dict that mimics HuggingFace dataset structure
|
||||
dataset = {"test": data["instances"]}
|
||||
elif "test" in data:
|
||||
dataset = {"test": data["test"]}
|
||||
else:
|
||||
# Assume the JSON file itself contains the instances
|
||||
dataset = {"test": data if isinstance(data, list) else [data]}
|
||||
dataset_name = os.path.basename(self.dataset_name_or_path).replace(
|
||||
".json", ""
|
||||
)
|
||||
else:
|
||||
dataset = load_from_disk(self.dataset_name_or_path)
|
||||
dataset_name = os.path.basename(self.dataset_name_or_path)
|
||||
else:
|
||||
dataset = load_dataset(self.dataset_name_or_path)
|
||||
dataset_name = self.dataset_name_or_path.replace("/", "__")
|
||||
|
||||
instances = []
|
||||
from datasets import DatasetDict
|
||||
|
||||
if isinstance(dataset, DatasetDict):
|
||||
available_splits = set(dataset.keys())
|
||||
if set(self.splits) - available_splits != set():
|
||||
missing_splits = set(self.splits) - available_splits
|
||||
logger.warning(f"Unknown splits {missing_splits}")
|
||||
|
||||
for split in self.splits:
|
||||
logger.info(f"Loading split '{split}'")
|
||||
from datasets import DatasetDict, IterableDatasetDict
|
||||
|
||||
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
|
||||
split_instances = list(dataset[split])
|
||||
elif isinstance(dataset, dict) and split in dataset:
|
||||
# Handle our custom JSON format
|
||||
split_instances = dataset[split]
|
||||
else:
|
||||
split_instances = list(dataset)
|
||||
instances.extend(split_instances)
|
||||
logger.info(f"Loaded {len(split_instances)} instances from split '{split}'")
|
||||
|
||||
output_file = Path(self.output_dir) / f"{dataset_name}__retrieval.jsonl"
output_file.parent.mkdir(parents=True, exist_ok=True)
# Keep a handle on the output file; run() appends JSONL results to it.
self.output_file = output_file
|
||||
|
||||
# Check for both JSONL format (for legacy compatibility) and directory structure format
|
||||
remaining_instances, processed_count = self._filter_existing_instances(
|
||||
instances, output_file
|
||||
)
|
||||
|
||||
if not remaining_instances:
|
||||
logger.info("All instances already processed")
|
||||
return []
|
||||
|
||||
# Apply max_instances limit if specified
|
||||
if self.max_instances is not None and self.max_instances > 0:
|
||||
# Check if we've already processed enough instances
|
||||
if processed_count >= self.max_instances:
|
||||
logger.info(
|
||||
f"Already processed {processed_count} instances, which meets or exceeds max_instances={self.max_instances}. No more instances to process."
|
||||
)
|
||||
return []
|
||||
|
||||
# Calculate how many more instances we need to process
|
||||
remaining_needed = self.max_instances - processed_count
|
||||
if len(remaining_instances) > remaining_needed:
|
||||
logger.info(
|
||||
f"Limiting to {remaining_needed} more instances (processed: {processed_count}, target: {self.max_instances}, remaining: {len(remaining_instances)})"
|
||||
)
|
||||
remaining_instances = remaining_instances[:remaining_needed]
|
||||
|
||||
return remaining_instances
|
||||
|
||||
def _filter_existing_instances(
|
||||
self, instances: List[Dict], output_file: Path
|
||||
) -> Tuple[List[Dict], int]:
|
||||
"""
|
||||
Filter instances to exclude those that have already been processed.
|
||||
|
||||
This method supports both output formats:
|
||||
1. JSONL format (legacy): results saved to a single JSONL file
|
||||
2. Directory format: results saved to individual directories with result.json files
|
||||
|
||||
Args:
|
||||
instances: List of instances to filter
|
||||
output_file: Path to the JSONL output file (used for legacy format detection)
|
||||
|
||||
Returns:
|
||||
Tuple of (remaining_instances, processed_count)
|
||||
"""
|
||||
# First check JSONL format for backward compatibility
|
||||
if output_file.exists():
|
||||
# JSONL format already handled by get_remaining_instances
|
||||
remaining_instances = get_remaining_instances(instances, output_file)
|
||||
processed_count = len(instances) - len(remaining_instances)
|
||||
return remaining_instances, processed_count
|
||||
else:
|
||||
# Check directory structure format
|
||||
processed_instance_ids = set()
|
||||
|
||||
# Check if output directory exists and has subdirectories with result.json
|
||||
if os.path.exists(self.output_dir):
|
||||
for item in os.listdir(self.output_dir):
|
||||
instance_dir = os.path.join(self.output_dir, item)
|
||||
result_file = os.path.join(instance_dir, "result.json")
|
||||
if os.path.isdir(instance_dir) and os.path.exists(result_file):
|
||||
processed_instance_ids.add(item)
|
||||
|
||||
processed_count = len(processed_instance_ids)
|
||||
if processed_count > 0:
|
||||
logger.info(
|
||||
f"Found {processed_count} existing instances in directory format. Will skip them."
|
||||
)
|
||||
|
||||
# Filter out already processed instances
|
||||
remaining_instances = [
|
||||
instance
|
||||
for instance in instances
|
||||
if instance["instance_id"] not in processed_instance_ids
|
||||
]
|
||||
|
||||
return remaining_instances, processed_count
|
||||
|
||||
def build_index(self, repo_path: str) -> Any:
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def search(self, repo_path: str, issue: str, k: int = 20) -> List[Dict[str, Any]]:
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def run(self, root_dir: str, token: str = "git") -> None:
|
||||
for instance in tqdm(self.instances, desc="Running retrieval"):
|
||||
instance_id = instance["instance_id"]
|
||||
repo = instance["repo"]
|
||||
commit = instance["base_commit"]
|
||||
issue = instance["problem_statement"]
|
||||
|
||||
try:
|
||||
repo_dir = clone_repo(repo, root_dir, token)
|
||||
|
||||
with ContextManager(str(repo_dir), commit):
|
||||
logger.info(f"Building index for {instance_id}")
|
||||
self.build_index(str(repo_dir))
|
||||
|
||||
logger.info(f"Searching for {instance_id}")
|
||||
hits = self.search(repo_dir, issue, k=20)
|
||||
|
||||
result = {"instance_id": instance_id, "hits": hits}
|
||||
|
||||
with open(self.output_file, "a") as f:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
logger.info(
|
||||
f"Retrieval completed. Results saved to {self.output_file}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {instance_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
continue
|
||||

evaluation/retrieval/custom.py (new file, 399 lines)
import asyncio
import json
import logging
import os
import sys
import time
import traceback
from contextlib import asynccontextmanager
from typing import List, Dict, Any

from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from tqdm.auto import tqdm

from client import Evaluator
from retrieval.base import BaseRetrieval
from utils.constant import project_path, evaluation_path
from utils.file_management import ContextManager, clone_repo
from utils.format import extract_oracle_files_from_patch, create_unified_diff_file
from utils.llm_factory import llm_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomRetrieval(BaseRetrieval):
|
||||
def __init__(
|
||||
self,
|
||||
llm_type: str,
|
||||
llm_model: str,
|
||||
retrieval_types: List[str],
|
||||
*,
|
||||
dataset_name_or_path,
|
||||
splits,
|
||||
output_dir,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize CustomRetrieval with specified retrieval types.
|
||||
|
||||
Args:
|
||||
llm_type: Type of LLM to use
|
||||
llm_model: LLM model name
|
||||
retrieval_types: List containing "cc", "grep", or both
|
||||
dataset_name_or_path: Dataset path
|
||||
splits: Dataset splits
|
||||
output_dir: Output directory
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_name_or_path=dataset_name_or_path,
|
||||
splits=splits,
|
||||
output_dir=output_dir,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Validate retrieval types
|
||||
valid_types = {"cc", "grep"}
|
||||
if not isinstance(retrieval_types, list):
|
||||
raise ValueError("retrieval_types must be a list")
|
||||
if not all(rt in valid_types for rt in retrieval_types):
|
||||
raise ValueError(
|
||||
f"retrieval_types must contain only 'cc' and/or 'grep', got: {retrieval_types}"
|
||||
)
|
||||
if not retrieval_types:
|
||||
raise ValueError("retrieval_types cannot be empty")
|
||||
|
||||
self.retrieval_types = retrieval_types
|
||||
self.llm_model = llm_factory(llm_type, llm_model)
|
||||
self.mcp_client = self._create_mcp_client()
|
||||
|
||||
def _create_mcp_client(self) -> MultiServerMCPClient:
|
||||
"""Create MCP client based on retrieval types"""
|
||||
servers = {
|
||||
"filesystem": {
|
||||
"command": sys.executable,
|
||||
"args": [str(evaluation_path / "servers/read_server.py"),],
|
||||
"transport": "stdio",
|
||||
},
|
||||
"edit": {
|
||||
"command": sys.executable,
|
||||
"args": [str(evaluation_path / "servers/edit_server.py"),],
|
||||
"transport": "stdio",
|
||||
},
|
||||
}
|
||||
|
||||
# Add CC server if needed
|
||||
if "cc" in self.retrieval_types:
|
||||
servers["claude-context"] = {
|
||||
# "command": "node",
|
||||
# "args": [str(project_path / "packages/mcp/dist/index.js")], # For development environment
|
||||
"command": "npx",
|
||||
"args": ["-y", "@zilliz/claude-context-mcp@0.1.0"], # For reproduction environment
|
||||
"env": {
|
||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
|
||||
"MILVUS_ADDRESS": os.getenv("MILVUS_ADDRESS"),
|
||||
"EMBEDDING_BATCH_SIZE": os.getenv("EMBEDDING_BATCH_SIZE", "100"),
|
||||
},
|
||||
"transport": "stdio",
|
||||
}
|
||||
|
||||
# Add Grep server if needed
|
||||
if "grep" in self.retrieval_types:
|
||||
servers["grep"] = {
|
||||
"command": sys.executable,
|
||||
"args": [str(evaluation_path / "servers/grep_server.py"),],
|
||||
"transport": "stdio",
|
||||
}
|
||||
|
||||
return MultiServerMCPClient(servers)
|
||||
|
||||
@asynccontextmanager
|
||||
async def mcp_sessions_context(self):
|
||||
"""Context manager for MCP sessions and tools loading"""
|
||||
# Build session context based on retrieval types
|
||||
session_names = ["filesystem", "edit"]
|
||||
|
||||
# Add CC session if needed
|
||||
if "cc" in self.retrieval_types:
|
||||
session_names.append("claude-context")
|
||||
|
||||
# Add Grep session if needed
|
||||
if "grep" in self.retrieval_types:
|
||||
session_names.append("grep")
|
||||
|
||||
# Create the appropriate context manager based on which sessions we need
|
||||
if len(session_names) == 2: # filesystem + edit
|
||||
async with self.mcp_client.session(
|
||||
"filesystem"
|
||||
) as fs_session, self.mcp_client.session("edit") as edit_session:
|
||||
sessions = {
|
||||
"filesystem": fs_session,
|
||||
"edit": edit_session,
|
||||
}
|
||||
yield await self._load_tools_from_sessions(sessions)
|
||||
elif len(session_names) == 3:
|
||||
if "claude-context" in session_names:
|
||||
async with self.mcp_client.session(
|
||||
"filesystem"
|
||||
) as fs_session, self.mcp_client.session(
|
||||
"edit"
|
||||
) as edit_session, self.mcp_client.session(
|
||||
"claude-context"
|
||||
) as cc_session:
|
||||
sessions = {
|
||||
"filesystem": fs_session,
|
||||
"edit": edit_session,
|
||||
"claude-context": cc_session,
|
||||
}
|
||||
yield await self._load_tools_from_sessions(sessions)
|
||||
else: # grep
|
||||
async with self.mcp_client.session(
|
||||
"filesystem"
|
||||
) as fs_session, self.mcp_client.session(
|
||||
"edit"
|
||||
) as edit_session, self.mcp_client.session(
|
||||
"grep"
|
||||
) as grep_session:
|
||||
sessions = {
|
||||
"filesystem": fs_session,
|
||||
"edit": edit_session,
|
||||
"grep": grep_session,
|
||||
}
|
||||
yield await self._load_tools_from_sessions(sessions)
|
||||
else: # all 4 sessions
|
||||
async with self.mcp_client.session(
|
||||
"filesystem"
|
||||
) as fs_session, self.mcp_client.session(
|
||||
"edit"
|
||||
) as edit_session, self.mcp_client.session(
|
||||
"claude-context"
|
||||
) as cc_session, self.mcp_client.session(
|
||||
"grep"
|
||||
) as grep_session:
|
||||
sessions = {
|
||||
"filesystem": fs_session,
|
||||
"edit": edit_session,
|
||||
"claude-context": cc_session,
|
||||
"grep": grep_session,
|
||||
}
|
||||
yield await self._load_tools_from_sessions(sessions)
|
||||
|
||||
async def _load_tools_from_sessions(self, sessions: Dict):
|
||||
"""Load tools from the provided sessions"""
|
||||
fs_tools = await load_mcp_tools(sessions["filesystem"])
|
||||
edit_tools = await load_mcp_tools(sessions["edit"])
|
||||
|
||||
# Get basic tools
|
||||
edit_tool = next((tool for tool in edit_tools if tool.name == "edit"), None,)
|
||||
|
||||
# Start with filesystem tools
|
||||
search_tools = [
|
||||
tool
|
||||
for tool in fs_tools
|
||||
if tool.name in ["read_file", "list_directory", "directory_tree"]
|
||||
]
|
||||
|
||||
# Add edit tool
|
||||
if edit_tool:
|
||||
search_tools.append(edit_tool)
|
||||
|
||||
# Initialize CC-specific tools
|
||||
cc_tools = {
|
||||
"index_tool": None,
|
||||
"indexing_status_tool": None,
|
||||
"clear_index_tool": None,
|
||||
"search_code_tool": None,
|
||||
}
|
||||
|
||||
# Load CC tools if needed
|
||||
if "cc" in self.retrieval_types and "claude-context" in sessions:
|
||||
cc_tool_list = await load_mcp_tools(sessions["claude-context"])
|
||||
|
||||
cc_tools["index_tool"] = next(
|
||||
(tool for tool in cc_tool_list if tool.name == "index_codebase"), None
|
||||
)
|
||||
cc_tools["indexing_status_tool"] = next(
|
||||
(tool for tool in cc_tool_list if tool.name == "get_indexing_status"),
|
||||
None,
|
||||
)
|
||||
cc_tools["clear_index_tool"] = next(
|
||||
(tool for tool in cc_tool_list if tool.name == "clear_index"), None
|
||||
)
|
||||
cc_tools["search_code_tool"] = next(
|
||||
(tool for tool in cc_tool_list if tool.name == "search_code"), None
|
||||
)
|
||||
|
||||
# Add search code tool to search tools
|
||||
if cc_tools["search_code_tool"]:
|
||||
search_tools.append(cc_tools["search_code_tool"])
|
||||
|
||||
# Load Grep tools if needed
|
||||
if "grep" in self.retrieval_types and "grep" in sessions:
|
||||
grep_tools = await load_mcp_tools(sessions["grep"])
|
||||
|
||||
# Add grep tool (typically the first one is search_text)
|
||||
if grep_tools:
|
||||
search_tools.append(grep_tools[0])
|
||||
|
||||
# Return tools as a dictionary for easy access
|
||||
return {
|
||||
"search_tools": search_tools,
|
||||
**cc_tools,
|
||||
}
|
||||
|
||||
def build_index(self, repo_path: str) -> Any:
|
||||
asyncio.run(self.async_build_index(repo_path))
|
||||
|
||||
async def async_build_index(self, repo_path: str) -> Any:
|
||||
"""Build index only if CC is enabled"""
|
||||
if "cc" not in self.retrieval_types:
|
||||
return
|
||||
|
||||
async with self.mcp_sessions_context() as tools:
|
||||
index_tool = tools["index_tool"]
|
||||
indexing_status_tool = tools["indexing_status_tool"]
|
||||
clear_index_tool = tools["clear_index_tool"]
|
||||
|
||||
if not index_tool or not indexing_status_tool or not clear_index_tool:
|
||||
raise RuntimeError("CC tools not found in MCP sessions")
|
||||
|
||||
try:
|
||||
await index_tool.ainvoke(
|
||||
{
|
||||
"path": repo_path,
|
||||
"force": False,
|
||||
"splitter": "ast",
|
||||
"customExtensions": [],
|
||||
"ignorePatterns": [],
|
||||
}
|
||||
)
|
||||
while True:
|
||||
status = await indexing_status_tool.ainvoke({"path": repo_path,})
|
||||
if "fully indexed and ready for search" in status:
|
||||
break
|
||||
time.sleep(2)
|
||||
# For strong consistency, wait for a while before searching
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
logger.error(f"Error building index: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
await clear_index_tool.ainvoke(
|
||||
{"path": repo_path,}
|
||||
)
|
||||
# For strong consistency, wait for a while before searching
|
||||
time.sleep(5)
|
||||
logger.info(f"Cleared index for {repo_path}")
|
||||
raise e
|
||||
|
||||
def search(self, repo_path: str, issue: str, k: int = 20) -> tuple:
|
||||
return asyncio.run(self.async_search(repo_path, issue, k))
|
||||
|
||||
async def async_search(self, repo_path: str, issue: str, k: int = 20) -> tuple:
|
||||
async with self.mcp_sessions_context() as tools:
|
||||
search_tools = tools["search_tools"]
|
||||
evaluator = Evaluator(self.llm_model, search_tools)
|
||||
query = self.prompt.format(repo_path=repo_path, issue=issue)
|
||||
|
||||
try:
|
||||
(
|
||||
conversation_summary,
|
||||
token_usage,
|
||||
file_paths,
|
||||
tool_stats,
|
||||
) = await evaluator.async_run(query, repo_path)
|
||||
finally:
|
||||
# Clear index if CC is enabled
|
||||
if "cc" in self.retrieval_types:
|
||||
clear_index_tool = tools["clear_index_tool"]
|
||||
if clear_index_tool:
|
||||
try:
|
||||
await clear_index_tool.ainvoke(
|
||||
{"path": repo_path,}
|
||||
)
|
||||
# For strong consistency, wait for a while before searching
|
||||
time.sleep(3)
|
||||
logger.info(f"Cleared index for {repo_path}")
|
||||
except Exception as clear_error:
|
||||
logger.warning(
|
||||
f"Failed to clear index for {repo_path}: {clear_error}"
|
||||
)
|
||||
|
||||
return file_paths, token_usage, conversation_summary, tool_stats
|
||||
|
||||
def run(self, root_dir: str, token: str = "git") -> None:
|
||||
asyncio.run(self.async_run(root_dir, token))
|
||||
|
||||
async def async_run(self, root_dir: str, token: str = "git") -> None:
|
||||
for instance in tqdm(self.instances, desc="Running retrieval"):
|
||||
instance_id = instance["instance_id"]
|
||||
repo = instance["repo"]
|
||||
commit = instance["base_commit"]
|
||||
issue = instance["problem_statement"]
|
||||
|
||||
# Create instance directory
|
||||
instance_dir = os.path.join(self.output_dir, instance_id)
|
||||
os.makedirs(instance_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
repo_dir = clone_repo(repo, root_dir, token)
|
||||
|
||||
with ContextManager(str(repo_dir), commit):
|
||||
logger.info(f"Building index for {instance_id}")
|
||||
await self.async_build_index(str(repo_dir))
|
||||
|
||||
logger.info(f"Searching for {instance_id}")
|
||||
(
|
||||
hits,
|
||||
token_usage,
|
||||
conversation_summary,
|
||||
tool_stats,
|
||||
) = await self.async_search(repo_dir, issue, k=20)
|
||||
|
||||
# Extract oracle files from patch
|
||||
oracles = extract_oracle_files_from_patch(instance.get("patch", ""))
|
||||
|
||||
# Prepare result data
|
||||
result = {
|
||||
"instance_id": instance_id,
|
||||
"hits": hits,
|
||||
"oracles": oracles,
|
||||
"token_usage": token_usage,
|
||||
"tool_stats": tool_stats,
|
||||
"retrieval_types": self.retrieval_types, # Add info about which retrieval types were used
|
||||
}
|
||||
|
||||
# Save result and token info to JSON file
|
||||
result_file = os.path.join(instance_dir, "result.json")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
|
||||
# Save conversation log
|
||||
log_file = os.path.join(instance_dir, "conversation.log")
|
||||
with open(log_file, "w") as f:
|
||||
f.write(conversation_summary)
|
||||
|
||||
# Create unified diff file from conversation log
|
||||
try:
|
||||
create_unified_diff_file(instance_dir, conversation_summary)
|
||||
logger.info(f"Created unified diff file for {instance_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to create unified diff file for {instance_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Retrieval completed for {instance_id}. Results saved to {instance_dir}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Save error stack trace to error.log
|
||||
error_file = os.path.join(instance_dir, "error.log")
|
||||
with open(error_file, "w") as f:
|
||||
f.write(f"Error processing {instance_id}: {e}\n\n")
|
||||
f.write(traceback.format_exc())
|
||||
|
||||
logger.error(f"Error processing {instance_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
continue
|
||||

evaluation/run_evaluation.py (new file, 123 lines)

import os
from argparse import ArgumentParser
from typing import List, Optional

from retrieval.custom import CustomRetrieval
from utils.constant import evaluation_path, project_path

import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)


def main(
    dataset_name_or_path: str,
    output_dir: str,
    retrieval_types: List[str],
    llm_type: str = "openai",
    llm_model: Optional[str] = None,
    splits: List[str] = ["test"],
    root_dir: str = str(evaluation_path / "repos"),
    max_instances: Optional[int] = 5,
):
    """
    Main function to run custom retrieval.

    Args:
        dataset_name_or_path: Dataset path or name
        output_dir: Output directory for results
        retrieval_types: List of retrieval types to use ('cc', 'grep', or both)
        llm_type: Type of LLM to use
        llm_model: LLM model name
        splits: Dataset splits to process
        root_dir: Root directory for repositories
        max_instances: Maximum number of instances to process
    """
    logger.info(f"Starting custom retrieval with types: {retrieval_types}")

    retrieval = CustomRetrieval(
        dataset_name_or_path=dataset_name_or_path,
        splits=splits,
        output_dir=output_dir,
        retrieval_types=retrieval_types,
        llm_type=llm_type,
        llm_model=llm_model,
        max_instances=max_instances,
    )

    retrieval.run(root_dir, token=os.environ.get("GITHUB_TOKEN", "git"))


def parse_retrieval_types(value: str) -> List[str]:
    """Parse comma-separated retrieval types string into list"""
    types = [t.strip().lower() for t in value.split(",")]
    valid_types = {"cc", "grep"}

    for t in types:
        if t not in valid_types:
            raise ValueError(
                f"Invalid retrieval type '{t}'. Must be one of: {valid_types}"
            )

    return types


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Custom Retrieval for SWE-bench with flexible retrieval types"
    )
    parser.add_argument(
        "--dataset_name_or_path",
        type=str,
        # default="SWE-bench/SWE-bench_Lite",
        default="swe_verified_15min1h_2files_instances.json",
        help="Dataset name or path",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(evaluation_path / "retrieval_results_custom"),
        help="Output directory",
    )
    parser.add_argument(
        "--retrieval_types",
        type=parse_retrieval_types,
        default="cc,grep",
        help="Comma-separated list of retrieval types to use. Options: 'cc', 'grep', or 'cc,grep' (default: 'cc,grep')",
    )
    parser.add_argument(
        "--llm_type",
        type=str,
        choices=["openai", "ollama", "moonshot"],
        # default="moonshot",
        default="openai",
        # default="anthropic",
        help="LLM type",
    )
    parser.add_argument(
        "--llm_model",
        type=str,
        # default="kimi-k2-0711-preview",
        default="gpt-4o-mini",
        # default="claude-sonnet-4-20250514",
        help="LLM model name, e.g. gpt-4o-mini",
    )
    parser.add_argument(
        "--splits", nargs="+", default=["test"], help="Dataset splits to process"
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        default=str(evaluation_path / "repos"),
        help="Temporary directory for repositories",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=5,
        help="Maximum number of instances to process (default: 5, set to -1 for all)",
    )

    args = parser.parse_args()
    main(**vars(args))

evaluation/servers/__init__.py (new empty file)

evaluation/servers/edit_server.py (new file, 38 lines)

#!/usr/bin/env python3
"""
An edit server using MCP (Model Context Protocol).
This server provides file editing functionality for modifying files.
"""

import os
from mcp.server.fastmcp import FastMCP

# Create the MCP server
mcp = FastMCP("Edit Server")


@mcp.tool()
def edit(file_path: str, old_string: str, new_string: str) -> str:
    """Edits the specified file with the given modifications.

    This tool marks files that need to be edited with the specified changes.

    Args:
        file_path: The absolute path to the file to modify.
        old_string: The exact literal text to replace. Must uniquely identify the single
            instance to change. Should include at least 3 lines of context before
            and after the target text, matching whitespace and indentation precisely.
            If old_string is empty, the tool attempts to create a new file at
            file_path with new_string as content.
        new_string: The exact literal text to replace old_string with.

    Returns:
        A string indicating the file has been successfully modified.
    """
    # Mock the edit operation: no file is actually changed; the evaluation
    # recovers the targeted file paths from the conversation log instead.
    return f"Successfully modified file: {file_path}"


if __name__ == "__main__":
    # Run the server with stdio transport
    mcp.run(transport="stdio")

evaluation/servers/grep_server.py (new file, 217 lines)
#!/usr/bin/env python3
|
||||
"""
|
||||
A grep server using MCP (Model Context Protocol).
|
||||
This server provides grep functionality to search for regular expression patterns within files.
|
||||
|
||||
Implementation logic inspired by Gemini CLI's grep.ts:
|
||||
https://github.com/google-gemini/gemini-cli/blob/main/packages/core/src/tools/grep.ts
|
||||
Adapted from TypeScript to Python implementation with similar fallback strategy.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Dict, Any, Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Create the MCP server
|
||||
mcp = FastMCP("Grep Server")
|
||||
|
||||
|
||||
def is_git_repository(path: str) -> bool:
|
||||
"""Check if the given path is inside a git repository."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--git-dir"],
|
||||
cwd=path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.SubprocessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def search_text(
|
||||
pattern: str, path: Optional[str] = None, include: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.
|
||||
|
||||
Args:
|
||||
pattern: The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').
|
||||
path: Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.
|
||||
include: Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).
|
||||
|
||||
Returns:
|
||||
A dictionary containing search results with file paths, line numbers, and matching lines.
|
||||
"""
|
||||
# Use current working directory if no path specified
|
||||
search_path = path if path else os.getcwd()
|
||||
|
||||
# Validate that the search path exists
|
||||
if not os.path.exists(search_path):
|
||||
return {"error": f"Path does not exist: {search_path}", "matches": []}
|
||||
|
||||
try:
|
||||
# Check if we're in a git repository and try git grep first
|
||||
if is_git_repository(search_path):
|
||||
try:
|
||||
# Build git grep command
|
||||
git_cmd = ["git", "grep", "-n", "-E"]
|
||||
|
||||
# Add include pattern if specified (git grep uses different syntax)
|
||||
if include:
|
||||
git_cmd.extend(["--", include])
|
||||
else:
|
||||
git_cmd.append("--")
|
||||
|
||||
# Add pattern
|
||||
git_cmd.insert(-1, pattern) # Insert pattern before the "--" separator
|
||||
|
||||
# Execute git grep command
|
||||
result = subprocess.run(
|
||||
git_cmd,
|
||||
cwd=search_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="ignore",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
# If git grep succeeds, use its output
|
||||
if result.returncode == 0:
|
||||
# Parse git grep output and return results
|
||||
matches = []
|
||||
if result.stdout:
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if ":" in line:
|
||||
# Parse git grep output format: filepath:line_number:content
|
||||
parts = line.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
file_path = parts[0]
|
||||
try:
|
||||
line_number = int(parts[1])
|
||||
line_content = parts[2]
|
||||
matches.append(
|
||||
{
|
||||
"file": os.path.join(
|
||||
search_path, file_path
|
||||
)
|
||||
if not os.path.isabs(file_path)
|
||||
else file_path,
|
||||
"line_number": line_number,
|
||||
"line_content": line_content,
|
||||
"match": pattern,
|
||||
}
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return {
|
||||
"pattern": pattern,
|
||||
"search_path": search_path,
|
||||
"total_matches": len(matches),
|
||||
"matches": matches,
|
||||
"command": " ".join(git_cmd),
|
||||
"method": "git grep",
|
||||
}
|
||||
|
||||
except (
|
||||
subprocess.SubprocessError,
|
||||
subprocess.TimeoutExpired,
|
||||
FileNotFoundError,
|
||||
):
|
||||
# Git grep failed, fall back to regular grep
|
||||
pass
|
||||
|
||||
# Fallback: Build regular grep command
|
||||
cmd = [
|
||||
"grep",
|
||||
"-n",
|
||||
"-r",
|
||||
"-E",
|
||||
] # -n for line numbers, -r for recursive, -E for extended regex
|
||||
|
||||
# Add include pattern if specified
|
||||
if include:
|
||||
cmd.extend(["--include", include])
|
||||
|
||||
# Add common exclusions
|
||||
cmd.extend(
|
||||
[
|
||||
"--exclude-dir=.git",
|
||||
"--exclude-dir=node_modules",
|
||||
"--exclude-dir=__pycache__",
|
||||
"--exclude-dir=.svn",
|
||||
"--exclude-dir=.hg",
|
||||
"--exclude-dir=venv",
|
||||
"--exclude-dir=env",
|
||||
"--exclude=*.pyc",
|
||||
"--exclude=*.pyo",
|
||||
"--exclude=*.so",
|
||||
"--exclude=*.dll",
|
||||
"--exclude=*.exe",
|
||||
"--exclude=*.jpg",
|
||||
"--exclude=*.jpeg",
|
||||
"--exclude=*.png",
|
||||
"--exclude=*.gif",
|
||||
"--exclude=*.zip",
|
||||
"--exclude=*.tar",
|
||||
"--exclude=*.gz",
|
||||
"--exclude=*.pdf",
|
||||
"--exclude=*.wasm",
|
||||
]
|
||||
)
|
||||
|
||||
# Add pattern and search path
|
||||
cmd.extend([pattern, search_path])
|
||||
|
||||
# Execute grep command
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, encoding="utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Parse grep output
|
||||
matches = []
|
||||
if result.stdout:
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if ":" in line:
|
||||
# Parse grep output format: filepath:line_number:content
|
||||
parts = line.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
file_path = parts[0]
|
||||
try:
|
||||
line_number = int(parts[1])
|
||||
line_content = parts[2]
|
||||
matches.append(
|
||||
{
|
||||
"file": file_path,
|
||||
"line_number": line_number,
|
||||
"line_content": line_content,
|
||||
"match": pattern, # grep already matched, so pattern is the match
|
||||
}
|
||||
)
|
||||
except ValueError:
|
||||
# Skip malformed lines
|
||||
continue
|
||||
|
||||
return {
|
||||
"pattern": pattern,
|
||||
"search_path": search_path,
|
||||
"total_matches": len(matches),
|
||||
"matches": matches,
|
||||
"command": " ".join(cmd), # Include the actual command for debugging
|
||||
"method": "system grep",
|
||||
}
|
||||
|
||||
except subprocess.SubprocessError as e:
|
||||
return {"error": f"Grep command failed: {str(e)}", "matches": []}
|
||||
except Exception as e:
|
||||
return {"error": f"Unexpected error: {str(e)}", "matches": []}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the server with stdio transport
|
||||
mcp.run(transport="stdio")
|
||||
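For reference, here is a minimal sketch of how a client could exercise the `search_text` tool over stdio. It assumes the MCP Python SDK's `StdioServerParameters`/`stdio_client`/`ClientSession` API and a hypothetical local path to `grep_server.py`; in the evaluation itself the servers are wired up through LangGraph's MCP adapter rather than called directly.

```python
# Sketch only (not part of the evaluation code): call search_text over stdio.
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main() -> None:
    # Hypothetical server path; adjust to wherever grep_server.py lives.
    server = StdioServerParameters(
        command="python", args=["evaluation/servers/grep_server.py"]
    )
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Search all Python files for function definitions named "run".
            result = await session.call_tool(
                "search_text", {"pattern": r"def\s+run", "include": "*.py"}
            )
            print(result.content)


if __name__ == "__main__":
    asyncio.run(main())
```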
272 evaluation/servers/read_server.py Normal file
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
A read_file server using MCP (Model Context Protocol).
This server provides file reading functionality for text files.

Implementation logic inspired by Gemini CLI's read-file.ts:
https://github.com/google-gemini/gemini-cli/blob/main/packages/core/src/tools/read-file.ts
Adapted from TypeScript to Python implementation with text file handling.
"""

import os
from typing import Dict, Any, Optional
from mcp.server.fastmcp import FastMCP

# Create the MCP server
mcp = FastMCP("Read File Server")


@mcp.tool()
def read_file(
    path: str, offset: Optional[int] = None, limit: Optional[int] = None
) -> Dict[str, Any]:
    """Reads the content of a text file at the specified path.

    You can optionally specify an offset and limit to read only a portion of the file.

    Args:
        path: The absolute path to the file to read.
        offset: Optional: The line number to start reading from (0-based).
        limit: Optional: The maximum number of lines to read.

    Returns:
        A dictionary containing either:
        - For text files: {"content": "file content as string", "type": "text", "total_lines": number}
        - For errors: {"error": "error message"}
    """
    try:
        # Validate path is absolute
        if not os.path.isabs(path):
            return {"error": f"Path must be absolute: {path}"}

        # Check if file exists
        if not os.path.exists(path):
            return {"error": f"File does not exist: {path}"}

        # Check if it's actually a file
        if os.path.isdir(path):
            return {"error": f"Path is a directory, not a file: {path}"}

        # Get file extension
        _, ext = os.path.splitext(path.lower())

        # Try to read as text file
        try:
            # Try to read as text with UTF-8 encoding
            with open(path, "r", encoding="utf-8", errors="replace") as file:
                lines = file.readlines()
                total_lines = len(lines)

                if offset is not None and limit is not None:
                    # Validate offset
                    if offset < 0:
                        return {"error": f"Offset must be non-negative: {offset}"}
                    if offset >= total_lines:
                        return {
                            "error": f"Offset {offset} is beyond file length {total_lines}"
                        }

                    # Calculate end position
                    start = offset
                    end = min(offset + limit, total_lines)
                    content = "".join(lines[start:end])

                    # Add truncation notice if needed
                    if end < total_lines:
                        content = (
                            f"[File content truncated: showing lines {start + 1}-{end} of {total_lines} total lines...]\n"
                            + content
                        )
                else:
                    content = "".join(lines)

                return {
                    "content": content,
                    "type": "text",
                    "total_lines": total_lines,
                    "path": path,
                }
        except UnicodeDecodeError:
            # If UTF-8 fails, try other common encodings
            for encoding in ["latin-1", "cp1252", "iso-8859-1"]:
                try:
                    with open(path, "r", encoding=encoding, errors="replace") as file:
                        lines = file.readlines()
                        total_lines = len(lines)

                        if offset is not None and limit is not None:
                            if offset < 0:
                                return {
                                    "error": f"Offset must be non-negative: {offset}"
                                }
                            if offset >= total_lines:
                                return {
                                    "error": f"Offset {offset} is beyond file length {total_lines}"
                                }

                            start = offset
                            end = min(offset + limit, total_lines)
                            content = "".join(lines[start:end])

                            if end < total_lines:
                                content = (
                                    f"[File content truncated: showing lines {start + 1}-{end} of {total_lines} total lines...]\n"
                                    + content
                                )
                        else:
                            content = "".join(lines)

                        return {
                            "content": content,
                            "type": "text",
                            "total_lines": total_lines,
                            "path": path,
                            "encoding": encoding,
                        }
                except UnicodeDecodeError:
                    continue

            # If all encodings fail, treat as binary
            return {"error": f"Cannot read file as text (encoding issues): {path}"}

    except Exception as e:
        return {"error": f"Unexpected error reading file: {str(e)}"}


@mcp.tool()
def list_directory(path: str) -> Dict[str, Any]:
    """Lists the contents of a directory.

    Args:
        path: The absolute path to the directory to list.

    Returns:
        A dictionary containing:
        - For success: {"entries": [{"name": "...", "type": "file|directory", "size": number}], "path": "..."}
        - For errors: {"error": "error message"}
    """
    try:
        # Validate path is absolute
        if not os.path.isabs(path):
            return {"error": f"Path must be absolute: {path}"}

        # Check if directory exists
        if not os.path.exists(path):
            return {"error": f"Directory does not exist: {path}"}

        # Check if it's actually a directory
        if not os.path.isdir(path):
            return {"error": f"Path is not a directory: {path}"}

        entries = []
        for item in os.listdir(path):
            item_path = os.path.join(path, item)
            try:
                if os.path.isfile(item_path):
                    size = os.path.getsize(item_path)
                    entries.append({"name": item, "type": "file", "size": size})
                elif os.path.isdir(item_path):
                    entries.append({"name": item, "type": "directory", "size": 0})
            except (OSError, PermissionError):
                # Skip items we can't access
                continue

        # Sort entries: directories first, then files, both alphabetically
        entries.sort(key=lambda x: (x["type"] == "file", x["name"].lower()))

        return {"entries": entries, "path": path, "total_count": len(entries)}

    except PermissionError:
        return {"error": f"Permission denied accessing directory: {path}"}
    except Exception as e:
        return {"error": f"Unexpected error listing directory: {str(e)}"}


@mcp.tool()
def directory_tree(path: str, max_depth: Optional[int] = 3) -> Dict[str, Any]:
    """Generates a tree structure of a directory.

    Args:
        path: The absolute path to the directory to generate tree for.
        max_depth: Optional: Maximum depth to traverse (default: 3).

    Returns:
        A dictionary containing:
        - For success: {"tree": "tree structure as string", "path": "..."}
        - For errors: {"error": "error message"}
    """
    try:
        # Validate path is absolute
        if not os.path.isabs(path):
            return {"error": f"Path must be absolute: {path}"}

        # Check if directory exists
        if not os.path.exists(path):
            return {"error": f"Directory does not exist: {path}"}

        # Check if it's actually a directory
        if not os.path.isdir(path):
            return {"error": f"Path is not a directory: {path}"}

        def build_tree(current_path: str, prefix: str = "", depth: int = 0) -> str:
            if max_depth and depth >= max_depth:
                return ""

            tree_str = ""
            try:
                items = sorted(os.listdir(current_path))
                for i, item in enumerate(items):
                    item_path = os.path.join(current_path, item)
                    is_last = i == len(items) - 1

                    # Skip hidden files and common ignore patterns
                    if item.startswith(".") and item not in [
                        ".env",
                        ".gitignore",
                        ".gitattributes",
                    ]:
                        continue
                    if item in [
                        "node_modules",
                        "__pycache__",
                        ".git",
                        ".svn",
                        ".hg",
                        "venv",
                        "env",
                    ]:
                        continue

                    try:
                        if os.path.isdir(item_path):
                            tree_str += (
                                f"{prefix}{'└── ' if is_last else '├── '}{item}/\n"
                            )
                            extension = "    " if is_last else "│   "
                            tree_str += build_tree(
                                item_path, prefix + extension, depth + 1
                            )
                        else:
                            tree_str += (
                                f"{prefix}{'└── ' if is_last else '├── '}{item}\n"
                            )
                    except (OSError, PermissionError):
                        # Skip items we can't access
                        continue
            except (OSError, PermissionError):
                pass

            return tree_str

        tree_structure = f"{os.path.basename(path)}/\n"
        tree_structure += build_tree(path)

        return {"tree": tree_structure, "path": path, "max_depth": max_depth}

    except Exception as e:
        return {"error": f"Unexpected error generating directory tree: {str(e)}"}


if __name__ == "__main__":
    # Run the server with stdio transport
    mcp.run(transport="stdio")
0 evaluation/utils/__init__.py Normal file
4 evaluation/utils/constant.py Normal file
@@ -0,0 +1,4 @@
from pathlib import Path

evaluation_path = Path(__file__).parent.parent.absolute()  # evaluation/
project_path = evaluation_path.parent.absolute()  # claude-context/
129 evaluation/utils/file_management.py Normal file
@@ -0,0 +1,129 @@
import os
import json
from pathlib import Path
import re
import logging
from git import Repo
from filelock import FileLock

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)


def get_remaining_instances(instances, output_file):
    """
    Filters a list of instances to exclude those that have already been processed and saved in a file.

    Args:
        instances (List[Dict]): A list of instances, where each instance is a dictionary with an "instance_id" key.
        output_file (Path): The path to the file where the processed instances are saved.

    Returns:
        List[Dict]: A list of instances that have not been processed yet.
    """
    instance_ids = set()
    remaining_instances = list()
    if output_file.exists():
        with FileLock(output_file.as_posix() + ".lock"):
            with open(output_file) as f:
                for line in f:
                    instance = json.loads(line)
                    instance_id = instance["instance_id"]
                    instance_ids.add(instance_id)
            logger.warning(
                f"Found {len(instance_ids)} existing instances in {output_file}. Will skip them."
            )
    else:
        output_file.parent.mkdir(parents=True, exist_ok=True)
        return instances
    for instance in instances:
        instance_id = instance["instance_id"]
        if instance_id not in instance_ids:
            remaining_instances.append(instance)
    return remaining_instances


def is_test(name, test_phrases=None):
    if test_phrases is None:
        test_phrases = ["test", "tests", "testing"]
    words = set(re.split(r" |_|\/|\.", name.lower()))
    return any(word in words for word in test_phrases)


def list_files(root_dir, include_tests=False):
    files = []
    for filename in Path(root_dir).rglob("*.py"):
        if not include_tests and is_test(filename.as_posix()):
            continue
        files.append(filename.relative_to(root_dir).as_posix())
    return files


class ContextManager:
    """
    A context manager for managing a Git repository at a specific commit.

    Args:
        repo_path (str): The path to the Git repository.
        base_commit (str): The commit hash to switch to.
        verbose (bool, optional): Whether to print verbose output. Defaults to False.

    Attributes:
        repo_path (str): The path to the Git repository.
        base_commit (str): The commit hash to switch to.
        verbose (bool): Whether to print verbose output.
        repo (git.Repo): The Git repository object.

    Methods:
        __enter__(): Switches to the specified commit and returns the context manager object.
        get_readme_files(): Returns a list of filenames for all README files in the repository.
        __exit__(exc_type, exc_val, exc_tb): Does nothing.
    """

    def __init__(self, repo_path, base_commit, verbose=False):
        self.repo_path = Path(repo_path).resolve().as_posix()
        self.base_commit = base_commit
        self.verbose = verbose
        self.repo = Repo(self.repo_path)

    def __enter__(self):
        if self.verbose:
            print(f"Switching to {self.base_commit}")
        try:
            self.repo.git.reset("--hard", self.base_commit)
            self.repo.git.clean("-fdxq")
        except Exception as e:
            logger.error(f"Failed to switch to {self.base_commit}")
            logger.error(e)
            raise e
        return self

    def get_readme_files(self):
        files = os.listdir(self.repo_path)
        files = list(filter(lambda x: os.path.isfile(x), files))
        files = list(filter(lambda x: x.lower().startswith("readme"), files))
        return files

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def clone_repo(repo, root_dir, token):
    """
    Clones a GitHub repository to a specified directory.

    Args:
        repo (str): The GitHub repository to clone.
        root_dir (str): The root directory to clone the repository to.
        token (str): The GitHub personal access token to use for authentication.

    Returns:
        Path: The path to the cloned repository directory.
    """
    repo_dir = Path(root_dir, f"repo__{repo.replace('/', '__')}")

    if not repo_dir.exists():
        repo_url = f"https://{token}@github.com/{repo}.git"
        logger.info(f"Cloning {repo} {os.getpid()}")
        Repo.clone_from(repo_url, repo_dir)
    return repo_dir
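As a rough illustration of how these helpers fit together when preparing a repository for one SWE-bench instance, here is a hedged sketch. It assumes the module is importable as `utils.file_management` from the `evaluation/` directory, a `GITHUB_TOKEN` environment variable, and a placeholder instance dict; the real orchestration lives in the evaluation scripts.

```python
# Sketch only: clone a repo, pin it to the instance's base commit, then list
# its non-test Python files. Field names follow SWE-bench instance records.
import os

from utils.file_management import ContextManager, clone_repo, list_files  # assumed import path

instance = {
    "repo": "astropy/astropy",          # illustrative repo name
    "base_commit": "<base-commit-sha>",  # placeholder; comes from the dataset
}

repo_dir = clone_repo(instance["repo"], "repos", os.getenv("GITHUB_TOKEN"))
with ContextManager(repo_dir, instance["base_commit"], verbose=True):
    python_files = list_files(repo_dir, include_tests=False)
    print(f"{len(python_files)} candidate files at base commit")
```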
451 evaluation/utils/format.py Normal file
@@ -0,0 +1,451 @@
import json
import re
import os


def extract_final_answer(response):
    """Extract the final answer from the agent response"""
    if "messages" in response:
        messages = response["messages"]
        # Get the last AI message
        for message in reversed(messages):
            if hasattr(message, "content") and isinstance(message.content, str):
                return message.content
            elif hasattr(message, "content") and isinstance(message.content, list):
                # Handle structured content
                for content_item in message.content:
                    if (
                        isinstance(content_item, dict)
                        and content_item.get("type") == "text"
                    ):
                        return content_item.get("text", "")
    return "No answer found"


def extract_file_paths_from_edits(response, codebase_path):
    """Extract file paths from edit tool responses and convert to relative paths"""
    import re

    file_paths = []
    seen_relative_paths = set()  # Use set for faster lookup
    codebase_path = os.path.abspath(codebase_path)

    # Extract the entire conversation content
    if hasattr(response, "get") and "messages" in response:
        # Handle LangGraph response format
        content = ""
        for message in response["messages"]:
            if hasattr(message, "content"):
                content += str(message.content) + "\n"
            elif isinstance(message, dict) and "content" in message:
                content += str(message["content"]) + "\n"
    else:
        # Fallback for other response formats
        content = str(response)

    # Pattern to match "Successfully modified file: /path/to/file"
    edit_pattern = r"Successfully modified file:\s*(.+?)(?:\s|$)"

    # Also check for edit tool calls in the response
    # Pattern to match edit tool calls with file_path parameter
    tool_call_pattern = r"edit.*?file_path[\"']?\s*:\s*[\"']([^\"']+)[\"']"

    for line in content.split("\n"):
        # Check for "Successfully modified file:" pattern
        match = re.search(edit_pattern, line.strip())
        if match:
            file_path = match.group(1).strip()
            # Convert to relative path immediately for deduplication
            rel_path = _normalize_to_relative_path(file_path, codebase_path)
            if rel_path and rel_path not in seen_relative_paths:
                seen_relative_paths.add(rel_path)
                file_paths.append(rel_path)

        # Check for edit tool calls
        match = re.search(tool_call_pattern, line.strip(), re.IGNORECASE)
        if match:
            file_path = match.group(1).strip()
            # Convert to relative path immediately for deduplication
            rel_path = _normalize_to_relative_path(file_path, codebase_path)
            if rel_path and rel_path not in seen_relative_paths:
                seen_relative_paths.add(rel_path)
                file_paths.append(rel_path)

    return file_paths


def _normalize_to_relative_path(file_path, codebase_path):
    """Convert a file path to relative path based on codebase_path"""
    if isinstance(file_path, str):
        if os.path.isabs(file_path):
            # Absolute path - convert to relative
            abs_path = os.path.abspath(file_path)
            if abs_path.startswith(codebase_path):
                return os.path.relpath(abs_path, codebase_path)
            else:
                # Path outside codebase, return as-is
                return file_path
        else:
            # Already relative path
            return file_path
    return None


def extract_oracle_files_from_patch(patch):
    """Extract the list of oracle files from the patch field"""
    import re

    if not patch:
        return []

    # Pattern to match patch headers like "--- a/path/to/file"
    patch_files_pattern = re.compile(r"\-\-\- a/(.+)")
    oracle_files = list(set(patch_files_pattern.findall(patch)))

    return oracle_files


def extract_edit_calls_from_conversation_log(log_content: str):
    """Extract all edit tool calls from conversation log content"""
    import re

    edit_calls = []

    # Split content into lines for processing
    lines = log_content.split("\n")
    i = 0

    while i < len(lines):
        line = lines[i]

        # Look for Arguments: line with edit tool (may have leading whitespace)
        if "Arguments:" in line and "'file_path'" in line:
            # Collect the full arguments block (might span multiple lines)
            args_block = line

            # Check if the line contains complete arguments
            if "}" in line:
                # Arguments are on a single line
                args_text = line
            else:
                # Arguments span multiple lines
                j = i + 1
                while j < len(lines) and "}" not in lines[j]:
                    args_block += (
                        "\n" + lines[j]
                    )  # Keep original formatting including newlines
                    j += 1
                if j < len(lines):
                    args_block += "\n" + lines[j]
                args_text = args_block

            # Extract file_path, old_string, new_string using regex
            file_path_match = re.search(r"'file_path':\s*'([^']*)'", args_text)
            # old_string can be either single-quoted or double-quoted
            old_string_match = re.search(
                r"'old_string':\s*[\"'](.*?)[\"'](?=,\s*'new_string')",
                args_text,
                re.DOTALL,
            )
            # new_string can be either single-quoted or double-quoted
            new_string_match = re.search(
                r"'new_string':\s*[\"'](.*?)[\"'](?=\s*})", args_text, re.DOTALL
            )

            if file_path_match and old_string_match and new_string_match:
                file_path = file_path_match.group(1)
                old_string = old_string_match.group(1)
                new_string = new_string_match.group(1)

                # Unescape newlines and clean up strings
                old_string = old_string.replace("\\n", "\n").replace("\\'", "'")
                new_string = new_string.replace("\\n", "\n").replace("\\'", "'")

                edit_calls.append(
                    {
                        "file_path": file_path,
                        "old_string": old_string,
                        "new_string": new_string,
                    }
                )

        i += 1

    return edit_calls


def find_line_number_for_old_string(file_path: str, old_string: str):
    """Find the line number where old_string starts in the file"""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()

        # Find the position of old_string in the content
        pos = content.find(old_string)
        if pos == -1:
            return None

        # Count lines up to that position
        line_num = content[:pos].count("\n") + 1
        return line_num
    except Exception:
        return None


def generate_unified_diff(file_path: str, old_string: str, new_string: str):
    """Generate unified diff format for a single edit"""
    import difflib
    import os

    # Get the relative file path for cleaner display
    rel_path = os.path.relpath(file_path) if os.path.exists(file_path) else file_path

    # Find line number where change occurs
    start_line = find_line_number_for_old_string(file_path, old_string)

    # Split strings into lines for difflib
    old_lines = old_string.splitlines(keepends=True)
    new_lines = new_string.splitlines(keepends=True)

    # Generate diff with context
    diff_lines = list(
        difflib.unified_diff(
            old_lines,
            new_lines,
            fromfile=f"a/{rel_path}",
            tofile=f"b/{rel_path}",
            lineterm="",
            n=3,  # 3 lines of context
        )
    )

    # If we found the line number, add it as a comment
    result = []
    if start_line is not None:
        result.append(f"# Edit starting at line {start_line}")

    result.extend(diff_lines)
    return "\n".join(result)


def create_unified_diff_file(instance_dir: str, conversation_summary: str) -> None:
    """Create a unified diff file from conversation log content"""
    edit_calls = extract_edit_calls_from_conversation_log(conversation_summary)

    if not edit_calls:
        return

    diff_content = []
    diff_content.append("# Unified diff of all edits made during retrieval")
    diff_content.append("# Generated from conversation log")
    diff_content.append("")

    for i, edit_call in enumerate(edit_calls, 1):
        diff_content.append(f"# Edit {i}: {edit_call['file_path']}")
        diff_content.append("")

        unified_diff = generate_unified_diff(
            edit_call["file_path"], edit_call["old_string"], edit_call["new_string"]
        )

        diff_content.append(unified_diff)
        diff_content.append("")
        diff_content.append("=" * 80)
        diff_content.append("")

    # Write to changes.diff file
    diff_file = os.path.join(instance_dir, "changes.diff")
    with open(diff_file, "w", encoding="utf-8") as f:
        f.write("\n".join(diff_content))


def calculate_total_tokens(response):
    """Calculate total token usage from the response"""
    total_input_tokens = 0
    total_output_tokens = 0
    total_tokens = 0
    max_single_turn_tokens = 0

    if "messages" in response:
        messages = response["messages"]

        for message in messages:
            current_turn_tokens = 0

            # Check for usage metadata in AI messages
            if hasattr(message, "usage_metadata"):
                usage = message.usage_metadata
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)
                turn_total = usage.get("total_tokens", input_tokens + output_tokens)

                total_input_tokens += input_tokens
                total_output_tokens += output_tokens
                total_tokens += turn_total
                current_turn_tokens = turn_total

            # Also check response_metadata for additional usage info
            elif (
                hasattr(message, "response_metadata")
                and "usage" in message.response_metadata
            ):
                usage = message.response_metadata["usage"]
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)

                total_input_tokens += input_tokens
                total_output_tokens += output_tokens

                # Calculate total if not provided
                if "total_tokens" in usage:
                    turn_total = usage["total_tokens"]
                    total_tokens += turn_total
                else:
                    turn_total = input_tokens + output_tokens
                    total_tokens += turn_total

                current_turn_tokens = turn_total

            # Track maximum single turn tokens
            if current_turn_tokens > max_single_turn_tokens:
                max_single_turn_tokens = current_turn_tokens

    return {
        "input_tokens": total_input_tokens,
        "output_tokens": total_output_tokens,
        "total_tokens": (
            total_tokens
            if total_tokens > 0
            else total_input_tokens + total_output_tokens
        ),
        "max_single_turn_tokens": max_single_turn_tokens,
    }


def print_token_usage(response):
    """Print simple token usage statistics"""
    usage = calculate_total_tokens(response)

    print(f"📥 Input Tokens: {usage['input_tokens']:,}")
    print(f"📤 Output Tokens: {usage['output_tokens']:,}")
    print(f"🔢 Total Tokens: {usage['total_tokens']:,}")
    print(f"🎯 Max Single Turn: {usage['max_single_turn_tokens']:,}")


def truncate_long_content(content, max_lines=30):
    """Truncate content if it exceeds max_lines"""
    if not isinstance(content, str):
        content = str(content)

    lines = content.split("\n")
    if len(lines) <= max_lines:
        return content

    truncated = "\n".join(lines[:max_lines])
    remaining_lines = len(lines) - max_lines
    return f"{truncated}\n... {remaining_lines} more lines"


def extract_conversation_summary(response):
    """Extract conversation summary and return as (summary_string, tool_stats_dict)"""
    summary_lines = []
    tool_call_counts = {}  # Count of calls for each tool
    total_tool_calls = 0  # Total number of tool calls

    if "messages" in response:
        messages = response["messages"]

        summary_lines.append("📝 Conversation Summary:")
        summary_lines.append("=" * 50)

        for i, message in enumerate(messages):
            if hasattr(message, "content"):
                if hasattr(message, "role") or "Human" in str(type(message)):
                    # Human message
                    content = (
                        message.content
                        if isinstance(message.content, str)
                        else str(message.content)
                    )
                    summary_lines.append(f"👤 User: {content}")
                    summary_lines.append("=" * 50)

                elif "AI" in str(type(message)):
                    # AI message - extract text content
                    if isinstance(message.content, str):
                        summary_lines.append(f"🤖 LLM: {message.content}")
                        summary_lines.append("=" * 50)
                    elif isinstance(message.content, list):
                        for content_item in message.content:
                            if isinstance(content_item, dict):
                                if content_item.get("type") == "text":
                                    summary_lines.append(
                                        f"🤖 LLM: {content_item.get('text', '')}"
                                    )
                                    summary_lines.append("=" * 50)
                                elif content_item.get("type") == "tool_use":
                                    tool_name = content_item.get("name", "unknown")
                                    tool_input = content_item.get("input", {})
                                    tool_id = content_item.get("id", "unknown")

                                    # Count tool calls
                                    tool_call_counts[tool_name] = (
                                        tool_call_counts.get(tool_name, 0) + 1
                                    )
                                    total_tool_calls += 1

                                    summary_lines.append(f"🔧 Tool Call: '{tool_name}'")
                                    summary_lines.append(f"   ID: {tool_id}")
                                    summary_lines.append(f"   Arguments: {tool_input}")
                                    summary_lines.append("=" * 50)

                    # Also check for tool_calls attribute (LangChain format)
                    if hasattr(message, "tool_calls") and message.tool_calls:
                        for tool_call in message.tool_calls:
                            tool_name = tool_call.get("name", "unknown")
                            tool_args = tool_call.get("args", {})
                            tool_id = tool_call.get("id", "unknown")

                            # Count tool calls
                            tool_call_counts[tool_name] = (
                                tool_call_counts.get(tool_name, 0) + 1
                            )
                            total_tool_calls += 1

                            summary_lines.append(f"🔧 Tool Call: '{tool_name}'")
                            summary_lines.append(f"   ID: {tool_id}")
                            summary_lines.append(f"   Arguments: {tool_args}")
                            summary_lines.append("=" * 50)

                elif "Tool" in str(type(message)):
                    # Tool response
                    tool_name = getattr(message, "name", "unknown")
                    tool_call_id = getattr(message, "tool_call_id", "unknown")
                    content = getattr(message, "content", "no result")

                    # Truncate long content
                    truncated_content = truncate_long_content(content, max_lines=30)

                    summary_lines.append(f"⚙️ Tool Response: '{tool_name}'")
                    summary_lines.append(f"   Call ID: {tool_call_id}")
                    summary_lines.append(f"   Result: {truncated_content}")
                    summary_lines.append("=" * 50)

    # Build tool statistics
    tool_stats = {
        "tool_call_counts": tool_call_counts,
        "total_tool_calls": total_tool_calls,
    }

    return "\n".join(summary_lines), tool_stats


def print_conversation_summary(response):
    """Print a clean summary of the conversation"""
    summary, tool_stats = extract_conversation_summary(response)
    print(summary)
    print("\n🔧 Tool Usage Statistics:")
    print(f"   Total tool calls: {tool_stats['total_tool_calls']}")
    if tool_stats["tool_call_counts"]:
        for tool_name, count in tool_stats["tool_call_counts"].items():
            print(f"   {tool_name}: {count} calls")
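To make the aggregation in `calculate_total_tokens` concrete, here is a hedged sketch that feeds it a fake response with two AI turns carrying `usage_metadata`. The message objects are plain stand-ins that mimic the attribute shape of LangChain messages, not the real classes, and the import path assumes the script runs from the `evaluation/` directory.

```python
# Sketch only: fake messages with usage_metadata to show what gets aggregated.
from types import SimpleNamespace

from utils.format import calculate_total_tokens  # assumed import path

fake_response = {
    "messages": [
        SimpleNamespace(
            usage_metadata={"input_tokens": 1200, "output_tokens": 150, "total_tokens": 1350}
        ),
        SimpleNamespace(
            usage_metadata={"input_tokens": 2400, "output_tokens": 300, "total_tokens": 2700}
        ),
    ]
}

usage = calculate_total_tokens(fake_response)
print(usage)
# Expected: 3600 input, 450 output, 4050 total, max single turn 2700
```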
21 evaluation/utils/llm_factory.py Normal file
@@ -0,0 +1,21 @@
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
import os


def llm_factory(llm_type: str, llm_model: str):
    if llm_type == "openai":
        return ChatOpenAI(model=llm_model)
    elif llm_type == "ollama":
        return ChatOllama(model=llm_model)
    elif llm_type == "moonshot":
        return ChatOpenAI(
            model=llm_model,
            base_url="https://api.moonshot.cn/v1",
            api_key=os.getenv("MOONSHOT_API_KEY"),
        )
    elif llm_type == "anthropic":
        return ChatAnthropic(model=llm_model, api_key=os.getenv("ANTHROPIC_API_KEY"))
    else:
        raise ValueError(f"Unsupported LLM type: {llm_type}")
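For orientation, a hedged sketch of how a factory-built model could be dropped into LangGraph's prebuilt ReAct agent, which is the agent framework used by the evaluation. The empty tools list and the user prompt are placeholders; in the actual runs the tools are loaded from the MCP servers above, and `llm_factory` is assumed importable as `utils.llm_factory`.

```python
# Sketch only: wire a factory-built model into a ReAct agent.
from langgraph.prebuilt import create_react_agent

from utils.llm_factory import llm_factory  # assumed import path

llm = llm_factory("openai", "gpt-4o-mini")   # default model used in the evaluation
agent = create_react_agent(llm, tools=[])    # tools come from the MCP servers in practice

result = agent.invoke(
    {"messages": [{"role": "user", "content": "Which files implement the edit tool?"}]}
)
print(result["messages"][-1].content)
```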
3210 evaluation/uv.lock generated Normal file
File diff suppressed because it is too large