Add Evaluation (#164)

* build evaluation framework

Signed-off-by: ChengZi <chen.zhang@zilliz.com>

* add

Signed-off-by: ChengZi <chen.zhang@zilliz.com>

* Update evaluation framework

Signed-off-by: Cheney <chen.zhang@zilliz.com>

---------

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Signed-off-by: Cheney <chen.zhang@zilliz.com>
Author: Cheney Zhang
Date: 2025-08-18 15:34:27 +08:00
Committed by: GitHub
Parent: 25e46ac9b6
Commit: 188529de44
23 changed files with 5752 additions and 0 deletions

.gitignore (vendored, +5)

@@ -61,3 +61,8 @@ __pycache__/
CLAUDE.md
.cursor/*
evaluation/repos
repos
evaluation/retrieval_results*

README.md (modified)

@@ -574,6 +574,16 @@ Check the `/examples` directory for complete usage examples:
---
## 📊 Evaluation
Our controlled evaluation demonstrates that Claude Context MCP achieves ~40% token reduction while maintaining equivalent retrieval quality. This translates into significant cost and time savings in production environments. It also means that, within a limited context-length budget, using Claude Context yields better retrieval and answer results.
![MCP Efficiency Analysis](assets/mcp_efficiency_analysis_chart.png)
For detailed evaluation methodology and results, see the [evaluation directory](evaluation/).
---
## ❓ FAQ
**Common Questions:**

Binary image added (170 KiB; the MCP efficiency analysis chart referenced above), not shown.

Python version file (new, 1 line)

@@ -0,0 +1 @@
3.10

evaluation/README.md (new file, +106)

@@ -0,0 +1,106 @@
# Claude Context MCP Evaluation
This directory contains the evaluation framework and experimental results for comparing the efficiency of code retrieval using Claude Context MCP versus traditional grep-only approaches.
## Overview
We conducted a controlled experiment to measure the impact of adding the Claude Context MCP tool to a baseline coding agent. The evaluation demonstrates significant improvements in token efficiency while maintaining comparable retrieval quality.
## Experimental Design
We designed a controlled experiment comparing two coding agents on identical retrieval tasks. The baseline agent uses simple tools: read, grep, and edit. The enhanced agent adds the Claude Context MCP tool on top of this same foundation. Both agents work on the same dataset with the same model to ensure a fair comparison, and both are implemented with the [LangGraph MCP and ReAct framework](https://langchain-ai.github.io/langgraph/agents/mcp/#use-mcp-tools).
We selected 30 instances from Princeton NLP's [SWE-bench_Verified](https://openai.com/index/introducing-swe-bench-verified/) dataset, filtering for problems rated at 15-60 minutes of difficulty with exactly 2 file modifications. This subset represents typical coding tasks and enables quick validation. Dataset generation is implemented in [`generate_subset_json.py`](./generate_subset_json.py).
We chose [GPT-4o-mini](https://platform.openai.com/docs/models/gpt-4o-mini) as the default model for cost-effectiveness.
We ran each method 3 times independently, giving 6 runs in total for statistical reliability. We measured token usage, tool calls, and retrieval precision, recall, and F1-score across all runs. The main entry point for running evaluations is [`run_evaluation.py`](./run_evaluation.py).
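Both agents follow the same pattern: MCP servers are launched over stdio, their tools are loaded, and a ReAct agent drives them. The sketch below is a condensed, illustrative version of that wiring; the full implementation (including the filesystem, edit, and grep servers, index management, and result logging) lives in [`client.py`](./client.py) and [`retrieval/custom.py`](./retrieval/custom.py). The model name and environment variables shown here are simply the defaults used in our runs.
```python
# Condensed sketch of the enhanced agent; see client.py and retrieval/custom.py for the real wiring.
import asyncio
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

client = MultiServerMCPClient({
    "claude-context": {
        "command": "npx",
        "args": ["-y", "@zilliz/claude-context-mcp@0.1.0"],
        "transport": "stdio",
        # env (OPENAI_API_KEY, MILVUS_ADDRESS, ...) is passed through as in retrieval/custom.py
    },
    # ...plus the filesystem, edit, and grep stdio servers from the servers/ directory
})

async def run_agent(query: str) -> dict:
    # Open a session to the MCP server, expose its tools to a ReAct agent, and run the query.
    async with client.session("claude-context") as session:
        tools = await load_mcp_tools(session)  # index_codebase, search_code, clear_index, ...
        agent = create_react_agent(ChatOpenAI(model="gpt-4o-mini"), tools)
        return await agent.ainvoke(
            {"messages": [{"role": "user", "content": query}]},
            config={"recursion_limit": 150},
        )

# asyncio.run(run_agent("The codebase is at /path/to/repo. Issue: <issue text>"))
```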
## Key Results
### Performance Summary
| Metric | Baseline (Grep Only) | With Claude Context MCP | Improvement |
|--------|---------------------|--------------------------|-------------|
| **Average F1-Score** | 0.40 | 0.40 | Comparable |
| **Average Token Usage** | 73,373 | 44,449 | **-39.4%** |
| **Average Tool Calls** | 8.3 | 5.3 | **-36.3%** |
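Precision, recall, and F1 are computed over sets of files: the files the agent edited (hits) versus the files changed by the reference patch (oracles). A minimal sketch of that metric, mirroring `calculate_metrics` in [`analyze_and_plot_mcp_efficiency.py`](./analyze_and_plot_mcp_efficiency.py) (path normalization omitted):
```python
# File-level F1: hits are the files the agent edited, oracles come from the reference patch.
def f1_over_files(hits: list[str], oracles: list[str]) -> float:
    hits_set, oracles_set = set(hits), set(oracles)
    correct = len(hits_set & oracles_set)
    precision = correct / len(hits_set) if hits_set else 0.0
    recall = correct / len(oracles_set) if oracles_set else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# Example: the agent edits {a.py, b.py}, the reference patch touches {a.py, c.py} -> F1 = 0.5
assert f1_over_files(["a.py", "b.py"], ["a.py", "c.py"]) == 0.5
```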
### Key Findings
**Dramatic Efficiency Gains**:
With Claude Context MCP, we achieved:
- **39.4% reduction** in token consumption (28,924 tokens saved per instance)
- **36.3% reduction** in tool calls (3.0 fewer calls per instance)
## Conclusion
The results demonstrate that Claude Context MCP provides:
### Immediate Benefits
- **Cost Efficiency**: ~40% reduction in token usage directly reduces operational costs
- **Speed Improvement**: Fewer tool calls and tokens mean faster code localization and task completion
- **Better Quality**: Within a fixed context-length budget, using Claude Context yields better retrieval and answer results
### Strategic Advantages
- **Better Resource Utilization**: Under fixed token budgets, Claude Context MCP enables handling more tasks
- **Wider Applicability**: Lower per-task costs make the approach practical for a broader range of scenarios
- **Improved User Experience**: Faster responses with maintained accuracy
## Running the Evaluation
To reproduce these results:
1. **Install Dependencies**:
For the Python environment, you can use `uv` to install the dependencies from the lockfile.
```bash
cd evaluation && uv sync
source .venv/bin/activate
```
For the Node environment, make sure your Node.js version satisfies `>= 20.0.0 and < 24.0.0`.
Our evaluation results were produced with `claude-context-mcp@0.1.0`; you can change the `claude-context` MCP server setting in the `retrieval/custom.py` file to use the latest version or a development build.
2. **Set Environment Variables**:
```bash
export OPENAI_API_KEY=your_openai_api_key
export MILVUS_ADDRESS=your_milvus_address
```
For more configuration details, refer to the `claude-context` MCP server settings in the `retrieval/custom.py` file (the relevant entry is excerpted below).
You also need a `GITHUB_TOKEN` so that the repositories can be cloned automatically; see the [SWE-bench documentation](https://www.swebench.com/SWE-bench/guides/create_rag_datasets/#example-usage) for details.
```bash
export GITHUB_TOKEN=your_github_token
```
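For reference, the `claude-context` server entry configured in `retrieval/custom.py` looks like this; it is pinned to the version we evaluated, and the environment variables above are forwarded to the server process.
```python
import os

# The claude-context MCP server entry used in retrieval/custom.py (stdio transport).
claude_context_server = {
    "command": "npx",
    "args": ["-y", "@zilliz/claude-context-mcp@0.1.0"],  # pinned to the evaluated version
    "env": {
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        "MILVUS_ADDRESS": os.getenv("MILVUS_ADDRESS"),
        "EMBEDDING_BATCH_SIZE": os.getenv("EMBEDDING_BATCH_SIZE", "100"),
    },
    "transport": "stdio",
}
```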
3. **Generate Dataset**:
```bash
python generate_subset_json.py
```
4. **Run Baseline Evaluation**:
```bash
python run_evaluation.py --retrieval_types grep --output_dir retrieval_results_grep
```
5. **Run Enhanced Evaluation**:
```bash
python run_evaluation.py --retrieval_types cc,grep --output_dir retrieval_results_both
```
6. **Analyze Results**:
```bash
python analyze_and_plot_mcp_efficiency.py
```
The evaluation framework is designed to be reproducible and can easily be extended to test additional configurations or datasets. Because LLM outputs are nondeterministic, exact numerical results may vary between runs and cannot be guaranteed to be identical; however, the core conclusions drawn from the analysis remain consistent and robust across runs.
## Results Visualization
![MCP Efficiency Analysis](../assets/mcp_efficiency_analysis_chart.png)
*The chart above shows the dramatic efficiency improvements achieved by Claude Context MCP while maintaining equivalent retrieval quality. Token usage and tool calls are significantly reduced with no loss in F1-score performance.*

evaluation/analyze_and_plot_mcp_efficiency.py (new file, +372)

@@ -0,0 +1,372 @@
#!/usr/bin/env python3
"""
Analyze retrieval results and create MCP efficiency chart using real data.
This script loads data from the actual result directories and generates seaborn charts.
"""
import json
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from typing import Dict, List, Tuple
def normalize_file_path(file_path: str) -> str:
"""Normalize file paths."""
if file_path.startswith("/"):
file_path = file_path[1:]
return file_path
def calculate_metrics(hits: List[str], oracles: List[str]) -> Dict[str, float]:
"""Calculate precision, recall, and F1-score."""
if not hits and not oracles:
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
# Normalize file paths
hits_set = set(normalize_file_path(f) for f in hits)
oracles_set = set(normalize_file_path(f) for f in oracles)
# Calculate intersection
intersection = hits_set.intersection(oracles_set)
# Calculate metrics
precision = len(intersection) / len(hits_set) if hits_set else 0.0
recall = len(intersection) / len(oracles_set) if oracles_set else 0.0
f1 = (
2 * (precision * recall) / (precision + recall)
if (precision + recall) > 0
else 0.0
)
return {
"precision": precision,
"recall": recall,
"f1": f1,
"num_hits": len(hits_set),
"num_oracles": len(oracles_set),
"num_correct": len(intersection),
}
def load_method_results(method_dirs: List[str], method_name: str) -> Dict:
"""Load and aggregate results from multiple runs of the same method."""
all_f1_scores = []
all_token_usage = []
all_tool_calls = []
successful_instances = set()
print(f"\nLoading {method_name} method data from {len(method_dirs)} runs...")
for run_idx, run_dir in enumerate(method_dirs):
print(f" Processing run {run_idx + 1}: {run_dir}")
if not os.path.exists(run_dir):
print(f" Warning: Directory {run_dir} does not exist")
continue
run_success_count = 0
run_f1_scores = []
run_tokens = []
run_tools = []
for item in os.listdir(run_dir):
instance_dir = os.path.join(run_dir, item)
result_file = os.path.join(instance_dir, "result.json")
if os.path.isdir(instance_dir) and os.path.exists(result_file):
try:
with open(result_file, "r") as f:
data = json.load(f)
# Calculate F1-score
hits = data.get("hits", [])
oracles = data.get("oracles", [])
metrics = calculate_metrics(hits, oracles)
# Extract other metrics
tokens = data.get("token_usage", {}).get("total_tokens", 0)
tools = data.get("tool_stats", {}).get("total_tool_calls", 0)
# Store data
run_f1_scores.append(metrics["f1"])
run_tokens.append(tokens)
run_tools.append(tools)
successful_instances.add(item)
run_success_count += 1
except Exception as e:
print(f" Warning: Failed to load {result_file}: {e}")
continue
print(f" Loaded {run_success_count} successful instances")
# Add this run's data to overall collection
all_f1_scores.extend(run_f1_scores)
all_token_usage.extend(run_tokens)
all_tool_calls.extend(run_tools)
# Calculate aggregated statistics
results = {
"method_name": method_name,
"total_runs": len(method_dirs),
"successful_instances": len(successful_instances),
"avg_f1": np.mean(all_f1_scores) if all_f1_scores else 0,
"std_f1": np.std(all_f1_scores) if all_f1_scores else 0,
"avg_tokens": np.mean(all_token_usage) if all_token_usage else 0,
"std_tokens": np.std(all_token_usage) if all_token_usage else 0,
"avg_tools": np.mean(all_tool_calls) if all_tool_calls else 0,
"std_tools": np.std(all_tool_calls) if all_tool_calls else 0,
}
print(f" Aggregated results:")
print(f" Avg F1-Score: {results['avg_f1']:.3f} ± {results['std_f1']:.3f}")
print(f" Avg Tokens: {results['avg_tokens']:.0f} ± {results['std_tokens']:.0f}")
print(
f" Avg Tool Calls: {results['avg_tools']:.1f} ± {results['std_tools']:.1f}"
)
return results
def create_efficiency_chart(both_results: Dict, grep_results: Dict):
"""Create the efficiency comparison chart using Seaborn."""
# Set the aesthetic style
sns.set_style("whitegrid")
sns.set_palette("husl")
# Prepare data for plotting
data = {
"Method": [
"With claude-context MCP",
"Baseline",
"With claude-context MCP",
"Baseline",
],
"Metric": ["Token Usage", "Token Usage", "Tool Calls", "Tool Calls"],
"Value": [
both_results["avg_tokens"] / 1000, # Convert to thousands
grep_results["avg_tokens"] / 1000,
both_results["avg_tools"],
grep_results["avg_tools"],
],
}
df = pd.DataFrame(data)
# Create figure with custom styling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
# Custom color palette
colors = ["#3498db", "#e74c3c"] # Modern blue and red
# Token Usage subplot
token_data = df[df["Metric"] == "Token Usage"]
sns.barplot(
data=token_data,
x="Method",
y="Value",
ax=ax1,
palette=colors,
alpha=0.8,
edgecolor="white",
linewidth=2,
)
ax1.set_title("Token Usage", fontsize=18, fontweight="bold", pad=20)
ax1.set_ylabel("Average Tokens (K)", fontsize=14, fontweight="bold")
ax1.set_xlabel("")
ax1.tick_params(axis="x", labelsize=12)
ax1.tick_params(axis="y", labelsize=12)
# Set y-axis range with some padding
ax1.set_ylim(0, max(token_data["Value"]) * 1.15)
# Add value labels for token usage
token_values = [
both_results["avg_tokens"] / 1000,
grep_results["avg_tokens"] / 1000,
]
for i, val in enumerate(token_values):
ax1.text(
i,
val + 2,
f"{val:.1f}K",
ha="center",
va="bottom",
fontweight="bold",
fontsize=13,
color=colors[i],
)
# Add improvement annotation for tokens
token_reduction = (
(grep_results["avg_tokens"] - both_results["avg_tokens"])
/ grep_results["avg_tokens"]
* 100
)
mid_height = max(token_values) * 0.8
ax1.annotate(
f"-{token_reduction:.1f}%",
xy=(0.5, mid_height),
xycoords="data",
ha="center",
va="center",
fontsize=16,
fontweight="bold",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="#2ecc71",
alpha=0.8,
edgecolor="white",
linewidth=2,
),
color="white",
)
# Tool Calls subplot
tool_data = df[df["Metric"] == "Tool Calls"]
sns.barplot(
data=tool_data,
x="Method",
y="Value",
ax=ax2,
palette=colors,
alpha=0.8,
edgecolor="white",
linewidth=2,
)
ax2.set_title("Tool Calls", fontsize=18, fontweight="bold", pad=20)
ax2.set_ylabel("Average Number of Calls", fontsize=14, fontweight="bold")
ax2.set_xlabel("")
ax2.tick_params(axis="x", labelsize=12)
ax2.tick_params(axis="y", labelsize=12)
# Set y-axis range with some padding
ax2.set_ylim(0, max(tool_data["Value"]) * 1.15)
# Add value labels for tool calls
tool_values = [both_results["avg_tools"], grep_results["avg_tools"]]
for i, val in enumerate(tool_values):
ax2.text(
i,
val + 0.2,
f"{val:.1f}",
ha="center",
va="bottom",
fontweight="bold",
fontsize=13,
color=colors[i],
)
# Add improvement annotation for tool calls
tool_reduction = (
(grep_results["avg_tools"] - both_results["avg_tools"])
/ grep_results["avg_tools"]
* 100
)
mid_height = max(tool_values) * 0.8
ax2.annotate(
f"-{tool_reduction:.1f}%",
xy=(0.5, mid_height),
xycoords="data",
ha="center",
va="center",
fontsize=16,
fontweight="bold",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="#2ecc71",
alpha=0.8,
edgecolor="white",
linewidth=2,
),
color="white",
)
# Keep x-axis labels horizontal and add grid
for ax in [ax1, ax2]:
ax.tick_params(axis="x", rotation=0)
ax.grid(True, alpha=0.3)
# Adjust layout
plt.tight_layout()
# Save with high quality
output_file = "mcp_efficiency_analysis_chart.png"
plt.savefig(
output_file, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none"
)
plt.show()
print(f"\nChart saved as: {output_file}")
# Print summary
print(f"\n{'='*80}")
print(f"MCP EFFICIENCY ANALYSIS SUMMARY")
print(f"{'='*80}")
print(f"Method Comparison:")
print(f" Both (MCP) vs Grep (Baseline)")
print(
f" Runs per method: {both_results['total_runs']} vs {grep_results['total_runs']}"
)
print(f"\nF1-Score Comparison:")
print(f" Both Method: {both_results['avg_f1']:.3f} ± {both_results['std_f1']:.3f}")
print(f" Grep Method: {grep_results['avg_f1']:.3f} ± {grep_results['std_f1']:.3f}")
f1_change = (
(
(both_results["avg_f1"] - grep_results["avg_f1"])
/ grep_results["avg_f1"]
* 100
)
if grep_results["avg_f1"] > 0
else 0
)
print(f" F1-Score change: {f1_change:+.1f}%")
print(f"\nEfficiency Improvements:")
print(
f" Token usage reduction: {token_reduction:.1f}% (from {grep_results['avg_tokens']:.0f} to {both_results['avg_tokens']:.0f})"
)
print(
f" Tool calls reduction: {tool_reduction:.1f}% (from {grep_results['avg_tools']:.1f} to {both_results['avg_tools']:.1f})"
)
print(
f" Average token savings: {grep_results['avg_tokens'] - both_results['avg_tokens']:.0f} tokens per instance"
)
def main():
"""Main function to analyze and plot MCP efficiency."""
print("MCP Efficiency Analysis - Loading Data")
print("=" * 60)
# Define directories for each method
both_dirs = [
"retrieval_results_both",
"retrieval_results_both2",
"retrieval_results_both3",
]
grep_dirs = [
"retrieval_results_grep",
"retrieval_results_grep2",
"retrieval_results_grep3",
]
# Load and analyze results
both_results = load_method_results(both_dirs, "Both (with claude-context MCP)")
grep_results = load_method_results(grep_dirs, "Grep (baseline)")
# Create the efficiency chart
create_efficiency_chart(both_results, grep_results)
if __name__ == "__main__":
main()

evaluation/client.py (new file, +62)

@@ -0,0 +1,62 @@
import asyncio
from langgraph.prebuilt import create_react_agent
from utils.format import (
extract_conversation_summary,
extract_file_paths_from_edits,
calculate_total_tokens,
)
class Evaluator:
"""Evaluator class for running LLM queries with MCP tools"""
def __init__(self, llm_model, tools):
"""
Initialize the Evaluator
Args:
llm_model: LangChain LLM model instance (required)
tools: List of tools to use (required)
"""
self.llm_model = llm_model
self.tools = tools
self.agent = create_react_agent(self.llm_model, self.tools)
# Setup event loop for sync usage
try:
self.loop = asyncio.get_event_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
async def async_run(self, query, codebase_path=None):
"""Internal async method to run the query"""
response = await self.agent.ainvoke(
{"messages": [{"role": "user", "content": query}]},
config={"recursion_limit": 150},
)
# Extract data without printing
conversation_summary, tool_stats = extract_conversation_summary(response)
token_usage = calculate_total_tokens(response)
if codebase_path:
file_paths = extract_file_paths_from_edits(response, codebase_path)
else:
file_paths = []
return conversation_summary, token_usage, file_paths, tool_stats
def run(self, query: str, codebase_path=None):
"""
Run a query synchronously
Args:
query (str): The query to execute
codebase_path (str): Path to the codebase for relative path conversion
Returns:
tuple: (conversation_summary, token_usage, file_paths, tool_stats)
"""
return asyncio.run(self.async_run(query, codebase_path))

evaluation/generate_subset_json.py (new file, +98)

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Generate swe_verified_15min1h_2files_instances.json from the subset analysis
"""
import json
import re
from datasets import load_dataset
def parse_patch_files(patch_content):
"""Parse patch content to extract the number of modified files"""
if not patch_content:
return []
file_pattern = r'^diff --git a/(.*?) b/(.*?)$'
files = []
for line in patch_content.split('\n'):
match = re.match(file_pattern, line)
if match:
file_path = match.group(1)
files.append(file_path)
return files
def main():
print("Loading SWE-bench_Verified dataset...")
dataset = load_dataset("princeton-nlp/SWE-bench_Verified")
instances = list(dataset['test'])
print("Filtering instances for: 15min-1hour difficulty + 2 patch files...")
# Filter for the specific subset
subset_instances = []
for instance in instances:
difficulty = instance.get('difficulty', 'Unknown')
# Parse main patch to count files
patch_content = instance.get('patch', '')
patch_files = parse_patch_files(patch_content)
oracle_count = len(patch_files)
# Check if it matches our criteria
if difficulty == '15 min - 1 hour' and oracle_count == 2:
subset_instances.append(instance)
print(f"Found {len(subset_instances)} instances matching criteria")
# Create the JSON structure that _prepare_instances expects
output_data = {
"metadata": {
"description": "SWE-bench_Verified subset: 15min-1hour difficulty with 2 patch files",
"source_dataset": "princeton-nlp/SWE-bench_Verified",
"extraction_date": "2024",
"filter_criteria": {
"difficulty": "15 min - 1 hour",
"patch_files_count": 2
},
"total_instances": len(subset_instances),
"statistics": {
"total_instances_in_original": 500,
"subset_count": len(subset_instances),
"percentage_of_original": round((len(subset_instances) / 500) * 100, 1)
}
},
"instances": subset_instances
}
# Save to JSON file
output_file = "swe_verified_15min1h_2files_instances.json"
with open(output_file, 'w') as f:
json.dump(output_data, f, indent=2)
print(f"Generated {output_file} with {len(subset_instances)} instances")
# Verify the structure
print("\nVerifying JSON structure...")
with open(output_file, 'r') as f:
loaded_data = json.load(f)
print(f"✓ Contains 'instances' key: {'instances' in loaded_data}")
print(f"✓ Contains 'metadata' key: {'metadata' in loaded_data}")
print(f"✓ Number of instances: {len(loaded_data['instances'])}")
print(f"✓ First instance has required fields:")
if loaded_data['instances']:
first_instance = loaded_data['instances'][0]
required_fields = ['instance_id', 'repo', 'base_commit', 'problem_statement']
for field in required_fields:
has_field = field in first_instance
print(f" - {field}: {'' if has_field else ''}")
print(f"\nFile successfully generated: {output_file}")
print("This file can be used with BaseRetrieval._prepare_instances()")
if __name__ == "__main__":
main()

evaluation/pyproject.toml (new file, +24)

@@ -0,0 +1,24 @@
[project]
name = "evaluation"
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"langchain-mcp-adapters>=0.1.9",
"langgraph>=0.6.4",
"mcp>=1.12.4",
"langchain>=0.1.0",
"langchain-core>=0.1.0",
"langchain-anthropic>=0.1.0",
"langchain-openai>=0.3.29",
"langchain-ollama>=0.3.6",
"datasets>=4.0.0",
"gitpython>=3.1.45",
"matplotlib>=3.10.5",
"seaborn>=0.13.2",
"pandas>=2.3.1",
"numpy>=2.2.6",
"plotly>=6.3.0",
"kaleido>=1.0.0",
]


evaluation/retrieval/base.py (new file, +210)

@@ -0,0 +1,210 @@
import json
import os
import traceback
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Dict, Any, Tuple
from datasets import load_from_disk, load_dataset
from utils.file_management import get_remaining_instances
from utils.file_management import ContextManager, clone_repo
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
class BaseRetrieval:
def __init__(
self, *, dataset_name_or_path, splits, output_dir, max_instances=None, **kwargs
):
self.dataset_name_or_path = dataset_name_or_path
self.splits = splits
self.output_dir = output_dir
self.max_instances = max_instances
self.instances = self._prepare_instances()
self.prompt = """The codebase is at {repo_path}.
Issue:
<issue>
{issue}
</issue>
Your task is to identify and edit the files that need to be modified to resolve the issue.
Focus on making the necessary changes to completely address the problem.
Use the available tools step by step to accomplish this goal. The primary objective is to edit the existing code files. No validation or testing is required.
"""
def _prepare_instances(self) -> List[Dict]:
if Path(self.dataset_name_or_path).exists():
# Check if it's a JSON file
if self.dataset_name_or_path.endswith(".json"):
with open(self.dataset_name_or_path, "r") as f:
data = json.load(f)
# If it's our custom JSON format with instances data
if "instances" in data:
logger.info(
f"Loaded {len(data['instances'])} instances from JSON file"
)
if "metadata" in data and "statistics" in data["metadata"]:
logger.info(f"Statistics: {data['metadata']['statistics']}")
# Create a simple dict that mimics HuggingFace dataset structure
dataset = {"test": data["instances"]}
elif "test" in data:
dataset = {"test": data["test"]}
else:
# Assume the JSON file itself contains the instances
dataset = {"test": data if isinstance(data, list) else [data]}
dataset_name = os.path.basename(self.dataset_name_or_path).replace(
".json", ""
)
else:
dataset = load_from_disk(self.dataset_name_or_path)
dataset_name = os.path.basename(self.dataset_name_or_path)
else:
dataset = load_dataset(self.dataset_name_or_path)
dataset_name = self.dataset_name_or_path.replace("/", "__")
instances = []
from datasets import DatasetDict
if isinstance(dataset, DatasetDict):
available_splits = set(dataset.keys())
if set(self.splits) - available_splits != set():
missing_splits = set(self.splits) - available_splits
logger.warning(f"Unknown splits {missing_splits}")
for split in self.splits:
logger.info(f"Loading split '{split}'")
from datasets import DatasetDict, IterableDatasetDict
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
split_instances = list(dataset[split])
elif isinstance(dataset, dict) and split in dataset:
# Handle our custom JSON format
split_instances = dataset[split]
else:
split_instances = list(dataset)
instances.extend(split_instances)
logger.info(f"Loaded {len(split_instances)} instances from split '{split}'")
output_file = Path(self.output_dir) / f"{dataset_name}__retrieval.jsonl"
output_file.parent.mkdir(parents=True, exist_ok=True)
# Check for both JSONL format (for legacy compatibility) and directory structure format
remaining_instances, processed_count = self._filter_existing_instances(
instances, output_file
)
if not remaining_instances:
logger.info("All instances already processed")
return []
# Apply max_instances limit if specified
if self.max_instances is not None and self.max_instances > 0:
# Check if we've already processed enough instances
if processed_count >= self.max_instances:
logger.info(
f"Already processed {processed_count} instances, which meets or exceeds max_instances={self.max_instances}. No more instances to process."
)
return []
# Calculate how many more instances we need to process
remaining_needed = self.max_instances - processed_count
if len(remaining_instances) > remaining_needed:
logger.info(
f"Limiting to {remaining_needed} more instances (processed: {processed_count}, target: {self.max_instances}, remaining: {len(remaining_instances)})"
)
remaining_instances = remaining_instances[:remaining_needed]
return remaining_instances
def _filter_existing_instances(
self, instances: List[Dict], output_file: Path
) -> Tuple[List[Dict], int]:
"""
Filter instances to exclude those that have already been processed.
This method supports both output formats:
1. JSONL format (legacy): results saved to a single JSONL file
2. Directory format: results saved to individual directories with result.json files
Args:
instances: List of instances to filter
output_file: Path to the JSONL output file (used for legacy format detection)
Returns:
Tuple of (remaining_instances, processed_count)
"""
# First check JSONL format for backward compatibility
if output_file.exists():
# JSONL format already handled by get_remaining_instances
remaining_instances = get_remaining_instances(instances, output_file)
processed_count = len(instances) - len(remaining_instances)
return remaining_instances, processed_count
else:
# Check directory structure format
processed_instance_ids = set()
# Check if output directory exists and has subdirectories with result.json
if os.path.exists(self.output_dir):
for item in os.listdir(self.output_dir):
instance_dir = os.path.join(self.output_dir, item)
result_file = os.path.join(instance_dir, "result.json")
if os.path.isdir(instance_dir) and os.path.exists(result_file):
processed_instance_ids.add(item)
processed_count = len(processed_instance_ids)
if processed_count > 0:
logger.info(
f"Found {processed_count} existing instances in directory format. Will skip them."
)
# Filter out already processed instances
remaining_instances = [
instance
for instance in instances
if instance["instance_id"] not in processed_instance_ids
]
return remaining_instances, processed_count
def build_index(self, repo_path: str) -> Any:
raise NotImplementedError("Subclasses must implement this method")
def search(self, repo_path: str, issue: str, k: int = 20) -> List[Dict[str, Any]]:
raise NotImplementedError("Subclasses must implement this method")
def run(self, root_dir: str, token: str = "git") -> None:
for instance in tqdm(self.instances, desc="Running retrieval"):
instance_id = instance["instance_id"]
repo = instance["repo"]
commit = instance["base_commit"]
issue = instance["problem_statement"]
try:
repo_dir = clone_repo(repo, root_dir, token)
with ContextManager(str(repo_dir), commit):
logger.info(f"Building index for {instance_id}")
self.build_index(str(repo_dir))
logger.info(f"Searching for {instance_id}")
hits = self.search(repo_dir, issue, k=20)
result = {"instance_id": instance_id, "hits": hits}
with open(self.output_file, "a") as f:
f.write(json.dumps(result) + "\n")
logger.info(
f"Retrieval completed. Results saved to {self.output_file}"
)
except Exception as e:
logger.error(f"Error processing {instance_id}: {e}")
logger.error(traceback.format_exc())
continue

evaluation/retrieval/custom.py (new file, +399)

@@ -0,0 +1,399 @@
import traceback
from typing import List, Dict, Any
import asyncio
from contextlib import asynccontextmanager
from retrieval.base import BaseRetrieval
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
import os
import logging
import sys
import time
from client import Evaluator
from utils.llm_factory import llm_factory
from utils.constant import project_path, evaluation_path
from utils.format import extract_oracle_files_from_patch, create_unified_diff_file
import json
from tqdm.auto import tqdm
from utils.file_management import ContextManager, clone_repo
logger = logging.getLogger(__name__)
class CustomRetrieval(BaseRetrieval):
def __init__(
self,
llm_type: str,
llm_model: str,
retrieval_types: List[str],
*,
dataset_name_or_path,
splits,
output_dir,
**kwargs,
):
"""
Initialize CustomRetrieval with specified retrieval types.
Args:
llm_type: Type of LLM to use
llm_model: LLM model name
retrieval_types: List containing "cc", "grep", or both
dataset_name_or_path: Dataset path
splits: Dataset splits
output_dir: Output directory
**kwargs: Additional arguments
"""
super().__init__(
dataset_name_or_path=dataset_name_or_path,
splits=splits,
output_dir=output_dir,
**kwargs,
)
# Validate retrieval types
valid_types = {"cc", "grep"}
if not isinstance(retrieval_types, list):
raise ValueError("retrieval_types must be a list")
if not all(rt in valid_types for rt in retrieval_types):
raise ValueError(
f"retrieval_types must contain only 'cc' and/or 'grep', got: {retrieval_types}"
)
if not retrieval_types:
raise ValueError("retrieval_types cannot be empty")
self.retrieval_types = retrieval_types
self.llm_model = llm_factory(llm_type, llm_model)
self.mcp_client = self._create_mcp_client()
def _create_mcp_client(self) -> MultiServerMCPClient:
"""Create MCP client based on retrieval types"""
servers = {
"filesystem": {
"command": sys.executable,
"args": [str(evaluation_path / "servers/read_server.py"),],
"transport": "stdio",
},
"edit": {
"command": sys.executable,
"args": [str(evaluation_path / "servers/edit_server.py"),],
"transport": "stdio",
},
}
# Add CC server if needed
if "cc" in self.retrieval_types:
servers["claude-context"] = {
# "command": "node",
# "args": [str(project_path / "packages/mcp/dist/index.js")], # For development environment
"command": "npx",
"args": ["-y", "@zilliz/claude-context-mcp@0.1.0"], # For reproduction environment
"env": {
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
"MILVUS_ADDRESS": os.getenv("MILVUS_ADDRESS"),
"EMBEDDING_BATCH_SIZE": os.getenv("EMBEDDING_BATCH_SIZE", "100"),
},
"transport": "stdio",
}
# Add Grep server if needed
if "grep" in self.retrieval_types:
servers["grep"] = {
"command": sys.executable,
"args": [str(evaluation_path / "servers/grep_server.py"),],
"transport": "stdio",
}
return MultiServerMCPClient(servers)
@asynccontextmanager
async def mcp_sessions_context(self):
"""Context manager for MCP sessions and tools loading"""
# Build session context based on retrieval types
session_names = ["filesystem", "edit"]
# Add CC session if needed
if "cc" in self.retrieval_types:
session_names.append("claude-context")
# Add Grep session if needed
if "grep" in self.retrieval_types:
session_names.append("grep")
# Create the appropriate context manager based on which sessions we need
if len(session_names) == 2: # filesystem + edit
async with self.mcp_client.session(
"filesystem"
) as fs_session, self.mcp_client.session("edit") as edit_session:
sessions = {
"filesystem": fs_session,
"edit": edit_session,
}
yield await self._load_tools_from_sessions(sessions)
elif len(session_names) == 3:
if "claude-context" in session_names:
async with self.mcp_client.session(
"filesystem"
) as fs_session, self.mcp_client.session(
"edit"
) as edit_session, self.mcp_client.session(
"claude-context"
) as cc_session:
sessions = {
"filesystem": fs_session,
"edit": edit_session,
"claude-context": cc_session,
}
yield await self._load_tools_from_sessions(sessions)
else: # grep
async with self.mcp_client.session(
"filesystem"
) as fs_session, self.mcp_client.session(
"edit"
) as edit_session, self.mcp_client.session(
"grep"
) as grep_session:
sessions = {
"filesystem": fs_session,
"edit": edit_session,
"grep": grep_session,
}
yield await self._load_tools_from_sessions(sessions)
else: # all 4 sessions
async with self.mcp_client.session(
"filesystem"
) as fs_session, self.mcp_client.session(
"edit"
) as edit_session, self.mcp_client.session(
"claude-context"
) as cc_session, self.mcp_client.session(
"grep"
) as grep_session:
sessions = {
"filesystem": fs_session,
"edit": edit_session,
"claude-context": cc_session,
"grep": grep_session,
}
yield await self._load_tools_from_sessions(sessions)
async def _load_tools_from_sessions(self, sessions: Dict):
"""Load tools from the provided sessions"""
fs_tools = await load_mcp_tools(sessions["filesystem"])
edit_tools = await load_mcp_tools(sessions["edit"])
# Get basic tools
edit_tool = next((tool for tool in edit_tools if tool.name == "edit"), None,)
# Start with filesystem tools
search_tools = [
tool
for tool in fs_tools
if tool.name in ["read_file", "list_directory", "directory_tree"]
]
# Add edit tool
if edit_tool:
search_tools.append(edit_tool)
# Initialize CC-specific tools
cc_tools = {
"index_tool": None,
"indexing_status_tool": None,
"clear_index_tool": None,
"search_code_tool": None,
}
# Load CC tools if needed
if "cc" in self.retrieval_types and "claude-context" in sessions:
cc_tool_list = await load_mcp_tools(sessions["claude-context"])
cc_tools["index_tool"] = next(
(tool for tool in cc_tool_list if tool.name == "index_codebase"), None
)
cc_tools["indexing_status_tool"] = next(
(tool for tool in cc_tool_list if tool.name == "get_indexing_status"),
None,
)
cc_tools["clear_index_tool"] = next(
(tool for tool in cc_tool_list if tool.name == "clear_index"), None
)
cc_tools["search_code_tool"] = next(
(tool for tool in cc_tool_list if tool.name == "search_code"), None
)
# Add search code tool to search tools
if cc_tools["search_code_tool"]:
search_tools.append(cc_tools["search_code_tool"])
# Load Grep tools if needed
if "grep" in self.retrieval_types and "grep" in sessions:
grep_tools = await load_mcp_tools(sessions["grep"])
# Add grep tool (typically the first one is search_text)
if grep_tools:
search_tools.append(grep_tools[0])
# Return tools as a dictionary for easy access
return {
"search_tools": search_tools,
**cc_tools,
}
def build_index(self, repo_path: str) -> Any:
asyncio.run(self.async_build_index(repo_path))
async def async_build_index(self, repo_path: str) -> Any:
"""Build index only if CC is enabled"""
if "cc" not in self.retrieval_types:
return
async with self.mcp_sessions_context() as tools:
index_tool = tools["index_tool"]
indexing_status_tool = tools["indexing_status_tool"]
clear_index_tool = tools["clear_index_tool"]
if not index_tool or not indexing_status_tool or not clear_index_tool:
raise RuntimeError("CC tools not found in MCP sessions")
try:
await index_tool.ainvoke(
{
"path": repo_path,
"force": False,
"splitter": "ast",
"customExtensions": [],
"ignorePatterns": [],
}
)
while True:
status = await indexing_status_tool.ainvoke({"path": repo_path,})
if "fully indexed and ready for search" in status:
break
time.sleep(2)
# For strong consistency, wait for a while before searching
time.sleep(5)
except Exception as e:
logger.error(f"Error building index: {e}")
logger.error(traceback.format_exc())
await clear_index_tool.ainvoke(
{"path": repo_path,}
)
# For strong consistency, wait for a while before searching
time.sleep(5)
logger.info(f"Cleared index for {repo_path}")
raise e
def search(self, repo_path: str, issue: str, k: int = 20) -> tuple:
return asyncio.run(self.async_search(repo_path, issue, k))
async def async_search(self, repo_path: str, issue: str, k: int = 20) -> tuple:
async with self.mcp_sessions_context() as tools:
search_tools = tools["search_tools"]
evaluator = Evaluator(self.llm_model, search_tools)
query = self.prompt.format(repo_path=repo_path, issue=issue)
try:
(
conversation_summary,
token_usage,
file_paths,
tool_stats,
) = await evaluator.async_run(query, repo_path)
finally:
# Clear index if CC is enabled
if "cc" in self.retrieval_types:
clear_index_tool = tools["clear_index_tool"]
if clear_index_tool:
try:
await clear_index_tool.ainvoke(
{"path": repo_path,}
)
# For strong consistency, wait for a while before searching
time.sleep(3)
logger.info(f"Cleared index for {repo_path}")
except Exception as clear_error:
logger.warning(
f"Failed to clear index for {repo_path}: {clear_error}"
)
return file_paths, token_usage, conversation_summary, tool_stats
def run(self, root_dir: str, token: str = "git") -> None:
asyncio.run(self.async_run(root_dir, token))
async def async_run(self, root_dir: str, token: str = "git") -> None:
for instance in tqdm(self.instances, desc="Running retrieval"):
instance_id = instance["instance_id"]
repo = instance["repo"]
commit = instance["base_commit"]
issue = instance["problem_statement"]
# Create instance directory
instance_dir = os.path.join(self.output_dir, instance_id)
os.makedirs(instance_dir, exist_ok=True)
try:
repo_dir = clone_repo(repo, root_dir, token)
with ContextManager(str(repo_dir), commit):
logger.info(f"Building index for {instance_id}")
await self.async_build_index(str(repo_dir))
logger.info(f"Searching for {instance_id}")
(
hits,
token_usage,
conversation_summary,
tool_stats,
) = await self.async_search(repo_dir, issue, k=20)
# Extract oracle files from patch
oracles = extract_oracle_files_from_patch(instance.get("patch", ""))
# Prepare result data
result = {
"instance_id": instance_id,
"hits": hits,
"oracles": oracles,
"token_usage": token_usage,
"tool_stats": tool_stats,
"retrieval_types": self.retrieval_types, # Add info about which retrieval types were used
}
# Save result and token info to JSON file
result_file = os.path.join(instance_dir, "result.json")
with open(result_file, "w") as f:
json.dump(result, f, indent=2)
# Save conversation log
log_file = os.path.join(instance_dir, "conversation.log")
with open(log_file, "w") as f:
f.write(conversation_summary)
# Create unified diff file from conversation log
try:
create_unified_diff_file(instance_dir, conversation_summary)
logger.info(f"Created unified diff file for {instance_id}")
except Exception as e:
logger.warning(
f"Failed to create unified diff file for {instance_id}: {e}"
)
logger.info(
f"Retrieval completed for {instance_id}. Results saved to {instance_dir}"
)
except Exception as e:
# Save error stack trace to error.log
error_file = os.path.join(instance_dir, "error.log")
with open(error_file, "w") as f:
f.write(f"Error processing {instance_id}: {e}\n\n")
f.write(traceback.format_exc())
logger.error(f"Error processing {instance_id}: {e}")
logger.error(traceback.format_exc())
continue

evaluation/run_evaluation.py (new file, +123)

@@ -0,0 +1,123 @@
import os
from argparse import ArgumentParser
from typing import List, Optional
from retrieval.custom import CustomRetrieval
from utils.constant import evaluation_path, project_path
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
def main(
dataset_name_or_path: str,
output_dir: str,
retrieval_types: List[str],
llm_type: str = "openai",
llm_model: Optional[str] = None,
splits: List[str] = ["test"],
root_dir: str = str(evaluation_path / "repos"),
max_instances: Optional[int] = 5,
):
"""
Main function to run custom retrieval.
Args:
dataset_name_or_path: Dataset path or name
output_dir: Output directory for results
retrieval_types: List of retrieval types to use ('cc', 'grep', or both)
llm_type: Type of LLM to use
llm_model: LLM model name
splits: Dataset splits to process
root_dir: Root directory for repositories
max_instances: Maximum number of instances to process
"""
logger.info(f"Starting custom retrieval with types: {retrieval_types}")
retrieval = CustomRetrieval(
dataset_name_or_path=dataset_name_or_path,
splits=splits,
output_dir=output_dir,
retrieval_types=retrieval_types,
llm_type=llm_type,
llm_model=llm_model,
max_instances=max_instances,
)
retrieval.run(root_dir, token=os.environ.get("GITHUB_TOKEN", "git"))
def parse_retrieval_types(value: str) -> List[str]:
"""Parse comma-separated retrieval types string into list"""
types = [t.strip().lower() for t in value.split(",")]
valid_types = {"cc", "grep"}
for t in types:
if t not in valid_types:
raise ValueError(
f"Invalid retrieval type '{t}'. Must be one of: {valid_types}"
)
return types
if __name__ == "__main__":
parser = ArgumentParser(
description="Custom Retrieval for SWE-bench with flexible retrieval types"
)
parser.add_argument(
"--dataset_name_or_path",
type=str,
# default="SWE-bench/SWE-bench_Lite",
default="swe_verified_15min1h_2files_instances.json",
help="Dataset name or path",
)
parser.add_argument(
"--output_dir",
type=str,
default=str(evaluation_path / "retrieval_results_custom"),
help="Output directory",
)
parser.add_argument(
"--retrieval_types",
type=parse_retrieval_types,
default="cc,grep",
help="Comma-separated list of retrieval types to use. Options: 'cc', 'grep', or 'cc,grep' (default: 'cc,grep')",
)
parser.add_argument(
"--llm_type",
type=str,
choices=["openai", "ollama", "moonshot"],
# default="moonshot",
default="openai",
# default="anthropic",
help="LLM type",
)
parser.add_argument(
"--llm_model",
type=str,
# default="kimi-k2-0711-preview",
default="gpt-4o-mini",
# default="claude-sonnet-4-20250514",
help="LLM model name, e.g. gpt-4o-mini",
)
parser.add_argument(
"--splits", nargs="+", default=["test"], help="Dataset splits to process"
)
parser.add_argument(
"--root_dir",
type=str,
default=str(evaluation_path / "repos"),
help="Temporary directory for repositories",
)
parser.add_argument(
"--max_instances",
type=int,
default=5,
help="Maximum number of instances to process (default: 5, set to -1 for all)",
)
args = parser.parse_args()
main(**vars(args))


evaluation/servers/edit_server.py (new file, +38)

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
An edit server using MCP (Model Context Protocol).
This server provides file editing functionality for modifying files.
"""
import os
from mcp.server.fastmcp import FastMCP
# Create the MCP server
mcp = FastMCP("Edit Server")
@mcp.tool()
def edit(file_path: str, old_string: str, new_string: str) -> str:
"""Edits the specified file with the given modifications.
This tool marks files that need to be edited with the specified changes.
Args:
file_path: The absolute path to the file to modify.
old_string: The exact literal text to replace. Must uniquely identify the single
instance to change. Should include at least 3 lines of context before
and after the target text, matching whitespace and indentation precisely.
If old_string is empty, the tool attempts to create a new file at
file_path with new_string as content.
new_string: The exact literal text to replace old_string with.
Returns:
A string indicating the file has been successfully modified.
"""
# Mock the edit operation
return f"Successfully modified file: {file_path}"
if __name__ == "__main__":
# Run the server with stdio transport
mcp.run(transport="stdio")

evaluation/servers/grep_server.py (new file, +217)

@@ -0,0 +1,217 @@
#!/usr/bin/env python3
"""
A grep server using MCP (Model Context Protocol).
This server provides grep functionality to search for regular expression patterns within files.
Implementation logic inspired by Gemini CLI's grep.ts:
https://github.com/google-gemini/gemini-cli/blob/main/packages/core/src/tools/grep.ts
Adapted from TypeScript to Python implementation with similar fallback strategy.
"""
import os
import subprocess
from typing import Dict, Any, Optional
from mcp.server.fastmcp import FastMCP
# Create the MCP server
mcp = FastMCP("Grep Server")
def is_git_repository(path: str) -> bool:
"""Check if the given path is inside a git repository."""
try:
result = subprocess.run(
["git", "rev-parse", "--git-dir"],
cwd=path,
capture_output=True,
text=True,
timeout=5,
)
return result.returncode == 0
except (subprocess.SubprocessError, subprocess.TimeoutExpired, FileNotFoundError):
return False
@mcp.tool()
def search_text(
pattern: str, path: Optional[str] = None, include: Optional[str] = None
) -> Dict[str, Any]:
"""Searches for a regular expression pattern within the content of files in a specified directory (or current working directory). Can filter files by a glob pattern. Returns the lines containing matches, along with their file paths and line numbers.
Args:
pattern: The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').
path: Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.
include: Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).
Returns:
A dictionary containing search results with file paths, line numbers, and matching lines.
"""
# Use current working directory if no path specified
search_path = path if path else os.getcwd()
# Validate that the search path exists
if not os.path.exists(search_path):
return {"error": f"Path does not exist: {search_path}", "matches": []}
try:
# Check if we're in a git repository and try git grep first
if is_git_repository(search_path):
try:
# Build git grep command
git_cmd = ["git", "grep", "-n", "-E"]
# Add include pattern if specified (git grep uses different syntax)
if include:
git_cmd.extend(["--", include])
else:
git_cmd.append("--")
# Add pattern
git_cmd.insert(-1, pattern) # Insert pattern before the "--" separator
# Execute git grep command
result = subprocess.run(
git_cmd,
cwd=search_path,
capture_output=True,
text=True,
encoding="utf-8",
errors="ignore",
timeout=30,
)
# If git grep succeeds, use its output
if result.returncode == 0:
# Parse git grep output and return results
matches = []
if result.stdout:
for line in result.stdout.strip().split("\n"):
if ":" in line:
# Parse git grep output format: filepath:line_number:content
parts = line.split(":", 2)
if len(parts) >= 3:
file_path = parts[0]
try:
line_number = int(parts[1])
line_content = parts[2]
matches.append(
{
"file": os.path.join(
search_path, file_path
)
if not os.path.isabs(file_path)
else file_path,
"line_number": line_number,
"line_content": line_content,
"match": pattern,
}
)
except ValueError:
continue
return {
"pattern": pattern,
"search_path": search_path,
"total_matches": len(matches),
"matches": matches,
"command": " ".join(git_cmd),
"method": "git grep",
}
except (
subprocess.SubprocessError,
subprocess.TimeoutExpired,
FileNotFoundError,
):
# Git grep failed, fall back to regular grep
pass
# Fallback: Build regular grep command
cmd = [
"grep",
"-n",
"-r",
"-E",
] # -n for line numbers, -r for recursive, -E for extended regex
# Add include pattern if specified
if include:
cmd.extend(["--include", include])
# Add common exclusions
cmd.extend(
[
"--exclude-dir=.git",
"--exclude-dir=node_modules",
"--exclude-dir=__pycache__",
"--exclude-dir=.svn",
"--exclude-dir=.hg",
"--exclude-dir=venv",
"--exclude-dir=env",
"--exclude=*.pyc",
"--exclude=*.pyo",
"--exclude=*.so",
"--exclude=*.dll",
"--exclude=*.exe",
"--exclude=*.jpg",
"--exclude=*.jpeg",
"--exclude=*.png",
"--exclude=*.gif",
"--exclude=*.zip",
"--exclude=*.tar",
"--exclude=*.gz",
"--exclude=*.pdf",
"--exclude=*.wasm",
]
)
# Add pattern and search path
cmd.extend([pattern, search_path])
# Execute grep command
result = subprocess.run(
cmd, capture_output=True, text=True, encoding="utf-8", errors="ignore"
)
# Parse grep output
matches = []
if result.stdout:
for line in result.stdout.strip().split("\n"):
if ":" in line:
# Parse grep output format: filepath:line_number:content
parts = line.split(":", 2)
if len(parts) >= 3:
file_path = parts[0]
try:
line_number = int(parts[1])
line_content = parts[2]
matches.append(
{
"file": file_path,
"line_number": line_number,
"line_content": line_content,
"match": pattern, # grep already matched, so pattern is the match
}
)
except ValueError:
# Skip malformed lines
continue
return {
"pattern": pattern,
"search_path": search_path,
"total_matches": len(matches),
"matches": matches,
"command": " ".join(cmd), # Include the actual command for debugging
"method": "system grep",
}
except subprocess.SubprocessError as e:
return {"error": f"Grep command failed: {str(e)}", "matches": []}
except Exception as e:
return {"error": f"Unexpected error: {str(e)}", "matches": []}
if __name__ == "__main__":
# Run the server with stdio transport
mcp.run(transport="stdio")

evaluation/servers/read_server.py (new file, +272)

@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
A read_file server using MCP (Model Context Protocol).
This server provides file reading functionality for text files.
Implementation logic inspired by Gemini CLI's read-file.ts:
https://github.com/google-gemini/gemini-cli/blob/main/packages/core/src/tools/read-file.ts
Adapted from TypeScript to Python implementation with text file handling.
"""
import os
from typing import Dict, Any, Optional
from mcp.server.fastmcp import FastMCP
# Create the MCP server
mcp = FastMCP("Read File Server")
@mcp.tool()
def read_file(
path: str, offset: Optional[int] = None, limit: Optional[int] = None
) -> Dict[str, Any]:
"""Reads the content of a text file at the specified path.
You can optionally specify an offset and limit to read only a portion of the file.
Args:
path: The absolute path to the file to read.
offset: Optional: The line number to start reading from (0-based).
limit: Optional: The maximum number of lines to read.
Returns:
A dictionary containing either:
- For text files: {"content": "file content as string", "type": "text", "total_lines": number}
- For errors: {"error": "error message"}
"""
try:
# Validate path is absolute
if not os.path.isabs(path):
return {"error": f"Path must be absolute: {path}"}
# Check if file exists
if not os.path.exists(path):
return {"error": f"File does not exist: {path}"}
# Check if it's actually a file
if os.path.isdir(path):
return {"error": f"Path is a directory, not a file: {path}"}
# Get file extension
_, ext = os.path.splitext(path.lower())
# Try to read as text file
try:
# Try to read as text with UTF-8 encoding
with open(path, "r", encoding="utf-8", errors="replace") as file:
lines = file.readlines()
total_lines = len(lines)
if offset is not None and limit is not None:
# Validate offset
if offset < 0:
return {"error": f"Offset must be non-negative: {offset}"}
if offset >= total_lines:
return {
"error": f"Offset {offset} is beyond file length {total_lines}"
}
# Calculate end position
start = offset
end = min(offset + limit, total_lines)
content = "".join(lines[start:end])
# Add truncation notice if needed
if end < total_lines:
content = (
f"[File content truncated: showing lines {start + 1}-{end} of {total_lines} total lines...]\n"
+ content
)
else:
content = "".join(lines)
return {
"content": content,
"type": "text",
"total_lines": total_lines,
"path": path,
}
except UnicodeDecodeError:
# If UTF-8 fails, try other common encodings
for encoding in ["latin-1", "cp1252", "iso-8859-1"]:
try:
with open(path, "r", encoding=encoding, errors="replace") as file:
lines = file.readlines()
total_lines = len(lines)
if offset is not None and limit is not None:
if offset < 0:
return {
"error": f"Offset must be non-negative: {offset}"
}
if offset >= total_lines:
return {
"error": f"Offset {offset} is beyond file length {total_lines}"
}
start = offset
end = min(offset + limit, total_lines)
content = "".join(lines[start:end])
if end < total_lines:
content = (
f"[File content truncated: showing lines {start + 1}-{end} of {total_lines} total lines...]\n"
+ content
)
else:
content = "".join(lines)
return {
"content": content,
"type": "text",
"total_lines": total_lines,
"path": path,
"encoding": encoding,
}
except UnicodeDecodeError:
continue
# If all encodings fail, treat as binary
return {"error": f"Cannot read file as text (encoding issues): {path}"}
except Exception as e:
return {"error": f"Unexpected error reading file: {str(e)}"}
@mcp.tool()
def list_directory(path: str) -> Dict[str, Any]:
"""Lists the contents of a directory.
Args:
path: The absolute path to the directory to list.
Returns:
A dictionary containing:
- For success: {"entries": [{"name": "...", "type": "file|directory", "size": number}], "path": "..."}
- For errors: {"error": "error message"}
"""
try:
# Validate path is absolute
if not os.path.isabs(path):
return {"error": f"Path must be absolute: {path}"}
# Check if directory exists
if not os.path.exists(path):
return {"error": f"Directory does not exist: {path}"}
# Check if it's actually a directory
if not os.path.isdir(path):
return {"error": f"Path is not a directory: {path}"}
entries = []
for item in os.listdir(path):
item_path = os.path.join(path, item)
try:
if os.path.isfile(item_path):
size = os.path.getsize(item_path)
entries.append({"name": item, "type": "file", "size": size})
elif os.path.isdir(item_path):
entries.append({"name": item, "type": "directory", "size": 0})
except (OSError, PermissionError):
# Skip items we can't access
continue
# Sort entries: directories first, then files, both alphabetically
entries.sort(key=lambda x: (x["type"] == "file", x["name"].lower()))
return {"entries": entries, "path": path, "total_count": len(entries)}
except PermissionError:
return {"error": f"Permission denied accessing directory: {path}"}
except Exception as e:
return {"error": f"Unexpected error listing directory: {str(e)}"}
@mcp.tool()
def directory_tree(path: str, max_depth: Optional[int] = 3) -> Dict[str, Any]:
"""Generates a tree structure of a directory.
Args:
path: The absolute path to the directory to generate tree for.
max_depth: Optional: Maximum depth to traverse (default: 3).
Returns:
A dictionary containing:
- For success: {"tree": "tree structure as string", "path": "..."}
- For errors: {"error": "error message"}
"""
try:
# Validate path is absolute
if not os.path.isabs(path):
return {"error": f"Path must be absolute: {path}"}
# Check if directory exists
if not os.path.exists(path):
return {"error": f"Directory does not exist: {path}"}
# Check if it's actually a directory
if not os.path.isdir(path):
return {"error": f"Path is not a directory: {path}"}
def build_tree(current_path: str, prefix: str = "", depth: int = 0) -> str:
if max_depth and depth >= max_depth:
return ""
tree_str = ""
try:
items = sorted(os.listdir(current_path))
for i, item in enumerate(items):
item_path = os.path.join(current_path, item)
is_last = i == len(items) - 1
# Skip hidden files and common ignore patterns
if item.startswith(".") and item not in [
".env",
".gitignore",
".gitattributes",
]:
continue
if item in [
"node_modules",
"__pycache__",
".git",
".svn",
".hg",
"venv",
"env",
]:
continue
try:
if os.path.isdir(item_path):
tree_str += (
f"{prefix}{'└── ' if is_last else '├── '}{item}/\n"
)
extension = " " if is_last else ""
tree_str += build_tree(
item_path, prefix + extension, depth + 1
)
else:
tree_str += (
f"{prefix}{'└── ' if is_last else '├── '}{item}\n"
)
except (OSError, PermissionError):
# Skip items we can't access
continue
except (OSError, PermissionError):
pass
return tree_str
tree_structure = f"{os.path.basename(path)}/\n"
tree_structure += build_tree(path)
return {"tree": tree_structure, "path": path, "max_depth": max_depth}
except Exception as e:
return {"error": f"Unexpected error generating directory tree: {str(e)}"}
if __name__ == "__main__":
# Run the server with stdio transport
mcp.run(transport="stdio")


evaluation/utils/constant.py (new file, +4)

@@ -0,0 +1,4 @@
from pathlib import Path
evaluation_path = Path(__file__).parent.parent.absolute() # evaluation/
project_path = evaluation_path.parent.absolute() # claude-context/

evaluation/utils/file_management.py (new file, +129)

@@ -0,0 +1,129 @@
import os
import json
from pathlib import Path
import re
import logging
from git import Repo
from filelock import FileLock
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
def get_remaining_instances(instances, output_file):
"""
Filters a list of instances to exclude those that have already been processed and saved in a file.
Args:
instances (List[Dict]): A list of instances, where each instance is a dictionary with an "instance_id" key.
output_file (Path): The path to the file where the processed instances are saved.
Returns:
List[Dict]: A list of instances that have not been processed yet.
"""
instance_ids = set()
remaining_instances = list()
if output_file.exists():
with FileLock(output_file.as_posix() + ".lock"):
with open(output_file) as f:
for line in f:
instance = json.loads(line)
instance_id = instance["instance_id"]
instance_ids.add(instance_id)
logger.warning(
f"Found {len(instance_ids)} existing instances in {output_file}. Will skip them."
)
else:
output_file.parent.mkdir(parents=True, exist_ok=True)
return instances
for instance in instances:
instance_id = instance["instance_id"]
if instance_id not in instance_ids:
remaining_instances.append(instance)
return remaining_instances
def is_test(name, test_phrases=None):
if test_phrases is None:
test_phrases = ["test", "tests", "testing"]
words = set(re.split(r" |_|\/|\.", name.lower()))
return any(word in words for word in test_phrases)
def list_files(root_dir, include_tests=False):
files = []
for filename in Path(root_dir).rglob("*.py"):
if not include_tests and is_test(filename.as_posix()):
continue
files.append(filename.relative_to(root_dir).as_posix())
return files
class ContextManager:
"""
A context manager for managing a Git repository at a specific commit.
Args:
repo_path (str): The path to the Git repository.
base_commit (str): The commit hash to switch to.
verbose (bool, optional): Whether to print verbose output. Defaults to False.
Attributes:
repo_path (str): The path to the Git repository.
base_commit (str): The commit hash to switch to.
verbose (bool): Whether to print verbose output.
repo (git.Repo): The Git repository object.
Methods:
__enter__(): Switches to the specified commit and returns the context manager object.
get_readme_files(): Returns a list of filenames for all README files in the repository.
__exit__(exc_type, exc_val, exc_tb): Does nothing.
"""
def __init__(self, repo_path, base_commit, verbose=False):
self.repo_path = Path(repo_path).resolve().as_posix()
self.base_commit = base_commit
self.verbose = verbose
self.repo = Repo(self.repo_path)
def __enter__(self):
if self.verbose:
print(f"Switching to {self.base_commit}")
try:
self.repo.git.reset("--hard", self.base_commit)
self.repo.git.clean("-fdxq")
except Exception as e:
logger.error(f"Failed to switch to {self.base_commit}")
logger.error(e)
raise e
return self
def get_readme_files(self):
files = os.listdir(self.repo_path)
        # Join with repo_path so the isfile check does not depend on the current working directory
        files = list(
            filter(lambda x: os.path.isfile(os.path.join(self.repo_path, x)), files)
        )
files = list(filter(lambda x: x.lower().startswith("readme"), files))
return files
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def clone_repo(repo, root_dir, token):
"""
Clones a GitHub repository to a specified directory.
Args:
repo (str): The GitHub repository to clone.
root_dir (str): The root directory to clone the repository to.
token (str): The GitHub personal access token to use for authentication.
Returns:
Path: The path to the cloned repository directory.
"""
repo_dir = Path(root_dir, f"repo__{repo.replace('/', '__')}")
if not repo_dir.exists():
repo_url = f"https://{token}@github.com/{repo}.git"
logger.info(f"Cloning {repo} {os.getpid()}")
Repo.clone_from(repo_url, repo_dir)
return repo_dir
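
A minimal usage sketch of these repository helpers, assuming a `GITHUB_TOKEN` environment variable and an illustrative repository and commit; the literal values below are examples, not part of the committed code.

# Illustrative sketch: clone a repo into evaluation/repos, pin it to a base commit,
# then collect its non-test Python files (relies on the helpers defined above).
token = os.environ["GITHUB_TOKEN"]  # assumed to be set
repo_dir = clone_repo("astropy/astropy", "evaluation/repos", token)
with ContextManager(repo_dir, base_commit="<base_commit_sha>", verbose=True):
    candidate_files = list_files(repo_dir, include_tests=False)
print(f"{len(candidate_files)} candidate .py files at the pinned commit")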

451
evaluation/utils/format.py Normal file

@@ -0,0 +1,451 @@
import json
import re
import os
def extract_final_answer(response):
"""Extract the final answer from the agent response"""
if "messages" in response:
messages = response["messages"]
# Get the last AI message
for message in reversed(messages):
if hasattr(message, "content") and isinstance(message.content, str):
return message.content
elif hasattr(message, "content") and isinstance(message.content, list):
# Handle structured content
for content_item in message.content:
if (
isinstance(content_item, dict)
and content_item.get("type") == "text"
):
return content_item.get("text", "")
return "No answer found"
def extract_file_paths_from_edits(response, codebase_path):
"""Extract file paths from edit tool responses and convert to relative paths"""
import re
file_paths = []
seen_relative_paths = set() # Use set for faster lookup
codebase_path = os.path.abspath(codebase_path)
# Extract the entire conversation content
if hasattr(response, "get") and "messages" in response:
# Handle LangGraph response format
content = ""
for message in response["messages"]:
if hasattr(message, "content"):
content += str(message.content) + "\n"
elif isinstance(message, dict) and "content" in message:
content += str(message["content"]) + "\n"
else:
# Fallback for other response formats
content = str(response)
# Pattern to match "Successfully modified file: /path/to/file"
edit_pattern = r"Successfully modified file:\s*(.+?)(?:\s|$)"
# Also check for edit tool calls in the response
# Pattern to match edit tool calls with file_path parameter
tool_call_pattern = r"edit.*?file_path[\"']?\s*:\s*[\"']([^\"']+)[\"']"
for line in content.split("\n"):
# Check for "Successfully modified file:" pattern
match = re.search(edit_pattern, line.strip())
if match:
file_path = match.group(1).strip()
# Convert to relative path immediately for deduplication
rel_path = _normalize_to_relative_path(file_path, codebase_path)
if rel_path and rel_path not in seen_relative_paths:
seen_relative_paths.add(rel_path)
file_paths.append(rel_path)
# Check for edit tool calls
match = re.search(tool_call_pattern, line.strip(), re.IGNORECASE)
if match:
file_path = match.group(1).strip()
# Convert to relative path immediately for deduplication
rel_path = _normalize_to_relative_path(file_path, codebase_path)
if rel_path and rel_path not in seen_relative_paths:
seen_relative_paths.add(rel_path)
file_paths.append(rel_path)
return file_paths
def _normalize_to_relative_path(file_path, codebase_path):
"""Convert a file path to relative path based on codebase_path"""
if isinstance(file_path, str):
if os.path.isabs(file_path):
# Absolute path - convert to relative
abs_path = os.path.abspath(file_path)
if abs_path.startswith(codebase_path):
return os.path.relpath(abs_path, codebase_path)
else:
# Path outside codebase, return as-is
return file_path
else:
# Already relative path
return file_path
return None
def extract_oracle_files_from_patch(patch):
"""Extract the list of oracle files from the patch field"""
import re
if not patch:
return []
# Pattern to match patch headers like "--- a/path/to/file"
patch_files_pattern = re.compile(r"\-\-\- a/(.+)")
oracle_files = list(set(patch_files_pattern.findall(patch)))
return oracle_files
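# Illustrative note (not part of the original file): a patch whose headers include
#   --- a/sympy/core/expr.py
# yields ["sympy/core/expr.py"], i.e. the ground-truth ("oracle") files that the
# agent's predicted edit paths are compared against.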
def extract_edit_calls_from_conversation_log(log_content: str):
"""Extract all edit tool calls from conversation log content"""
import re
edit_calls = []
# Split content into lines for processing
lines = log_content.split("\n")
i = 0
while i < len(lines):
line = lines[i]
# Look for Arguments: line with edit tool (may have leading whitespace)
if "Arguments:" in line and "'file_path'" in line:
# Collect the full arguments block (might span multiple lines)
args_block = line
# Check if the line contains complete arguments
if "}" in line:
# Arguments are on a single line
args_text = line
else:
# Arguments span multiple lines
j = i + 1
while j < len(lines) and "}" not in lines[j]:
args_block += (
"\n" + lines[j]
) # Keep original formatting including newlines
j += 1
if j < len(lines):
args_block += "\n" + lines[j]
args_text = args_block
# Extract file_path, old_string, new_string using regex
file_path_match = re.search(r"'file_path':\s*'([^']*)'", args_text)
# old_string can be either single-quoted or double-quoted
old_string_match = re.search(
r"'old_string':\s*[\"'](.*?)[\"'](?=,\s*'new_string')",
args_text,
re.DOTALL,
)
# new_string can be either single-quoted or double-quoted
new_string_match = re.search(
r"'new_string':\s*[\"'](.*?)[\"'](?=\s*})", args_text, re.DOTALL
)
if file_path_match and old_string_match and new_string_match:
file_path = file_path_match.group(1)
old_string = old_string_match.group(1)
new_string = new_string_match.group(1)
# Unescape newlines and clean up strings
old_string = old_string.replace("\\n", "\n").replace("\\'", "'")
new_string = new_string.replace("\\n", "\n").replace("\\'", "'")
edit_calls.append(
{
"file_path": file_path,
"old_string": old_string,
"new_string": new_string,
}
)
i += 1
return edit_calls
def find_line_number_for_old_string(file_path: str, old_string: str):
"""Find the line number where old_string starts in the file"""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Find the position of old_string in the content
pos = content.find(old_string)
if pos == -1:
return None
# Count lines up to that position
line_num = content[:pos].count("\n") + 1
return line_num
except Exception:
return None
def generate_unified_diff(file_path: str, old_string: str, new_string: str):
"""Generate unified diff format for a single edit"""
import difflib
import os
# Get the relative file path for cleaner display
rel_path = os.path.relpath(file_path) if os.path.exists(file_path) else file_path
# Find line number where change occurs
start_line = find_line_number_for_old_string(file_path, old_string)
# Split strings into lines for difflib
old_lines = old_string.splitlines(keepends=True)
new_lines = new_string.splitlines(keepends=True)
# Generate diff with context
diff_lines = list(
difflib.unified_diff(
old_lines,
new_lines,
fromfile=f"a/{rel_path}",
tofile=f"b/{rel_path}",
lineterm="",
n=3, # 3 lines of context
)
)
# If we found the line number, add it as a comment
result = []
if start_line is not None:
result.append(f"# Edit starting at line {start_line}")
result.extend(diff_lines)
return "\n".join(result)
def create_unified_diff_file(instance_dir: str, conversation_summary: str) -> None:
"""Create a unified diff file from conversation log content"""
edit_calls = extract_edit_calls_from_conversation_log(conversation_summary)
if not edit_calls:
return
diff_content = []
diff_content.append("# Unified diff of all edits made during retrieval")
diff_content.append("# Generated from conversation log")
diff_content.append("")
for i, edit_call in enumerate(edit_calls, 1):
diff_content.append(f"# Edit {i}: {edit_call['file_path']}")
diff_content.append("")
unified_diff = generate_unified_diff(
edit_call["file_path"], edit_call["old_string"], edit_call["new_string"]
)
diff_content.append(unified_diff)
diff_content.append("")
diff_content.append("=" * 80)
diff_content.append("")
# Write to changes.diff file
diff_file = os.path.join(instance_dir, "changes.diff")
with open(diff_file, "w", encoding="utf-8") as f:
f.write("\n".join(diff_content))
def calculate_total_tokens(response):
"""Calculate total token usage from the response"""
total_input_tokens = 0
total_output_tokens = 0
total_tokens = 0
max_single_turn_tokens = 0
if "messages" in response:
messages = response["messages"]
for message in messages:
current_turn_tokens = 0
# Check for usage metadata in AI messages
            # usage_metadata may exist on AI messages but be None when usage was not reported
            if getattr(message, "usage_metadata", None):
usage = message.usage_metadata
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
turn_total = usage.get("total_tokens", input_tokens + output_tokens)
total_input_tokens += input_tokens
total_output_tokens += output_tokens
total_tokens += turn_total
current_turn_tokens = turn_total
# Also check response_metadata for additional usage info
elif (
hasattr(message, "response_metadata")
and "usage" in message.response_metadata
):
usage = message.response_metadata["usage"]
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
total_input_tokens += input_tokens
total_output_tokens += output_tokens
# Calculate total if not provided
if "total_tokens" in usage:
turn_total = usage["total_tokens"]
total_tokens += turn_total
else:
turn_total = input_tokens + output_tokens
total_tokens += turn_total
current_turn_tokens = turn_total
# Track maximum single turn tokens
if current_turn_tokens > max_single_turn_tokens:
max_single_turn_tokens = current_turn_tokens
return {
"input_tokens": total_input_tokens,
"output_tokens": total_output_tokens,
"total_tokens": (
total_tokens
if total_tokens > 0
else total_input_tokens + total_output_tokens
),
"max_single_turn_tokens": max_single_turn_tokens,
}
def print_token_usage(response):
"""Print simple token usage statistics"""
usage = calculate_total_tokens(response)
print(f"📥 Input Tokens: {usage['input_tokens']:,}")
print(f"📤 Output Tokens: {usage['output_tokens']:,}")
print(f"🔢 Total Tokens: {usage['total_tokens']:,}")
print(f"🎯 Max Single Turn: {usage['max_single_turn_tokens']:,}")
def truncate_long_content(content, max_lines=30):
"""Truncate content if it exceeds max_lines"""
if not isinstance(content, str):
content = str(content)
lines = content.split("\n")
if len(lines) <= max_lines:
return content
truncated = "\n".join(lines[:max_lines])
remaining_lines = len(lines) - max_lines
return f"{truncated}\n... {remaining_lines} more lines"
def extract_conversation_summary(response):
"""Extract conversation summary and return as (summary_string, tool_stats_dict)"""
summary_lines = []
tool_call_counts = {} # Count of calls for each tool
total_tool_calls = 0 # Total number of tool calls
if "messages" in response:
messages = response["messages"]
summary_lines.append("📝 Conversation Summary:")
summary_lines.append("=" * 50)
for i, message in enumerate(messages):
if hasattr(message, "content"):
if hasattr(message, "role") or "Human" in str(type(message)):
# Human message
content = (
message.content
if isinstance(message.content, str)
else str(message.content)
)
summary_lines.append(f"👤 User: {content}")
summary_lines.append("=" * 50)
elif "AI" in str(type(message)):
# AI message - extract text content
if isinstance(message.content, str):
summary_lines.append(f"🤖 LLM: {message.content}")
summary_lines.append("=" * 50)
elif isinstance(message.content, list):
for content_item in message.content:
if isinstance(content_item, dict):
if content_item.get("type") == "text":
summary_lines.append(
f"🤖 LLM: {content_item.get('text', '')}"
)
summary_lines.append("=" * 50)
elif content_item.get("type") == "tool_use":
tool_name = content_item.get("name", "unknown")
tool_input = content_item.get("input", {})
tool_id = content_item.get("id", "unknown")
# Count tool calls
tool_call_counts[tool_name] = (
tool_call_counts.get(tool_name, 0) + 1
)
total_tool_calls += 1
summary_lines.append(f"🔧 Tool Call: '{tool_name}'")
summary_lines.append(f" ID: {tool_id}")
summary_lines.append(f" Arguments: {tool_input}")
summary_lines.append("=" * 50)
# Also check for tool_calls attribute (LangChain format)
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
tool_name = tool_call.get("name", "unknown")
tool_args = tool_call.get("args", {})
tool_id = tool_call.get("id", "unknown")
# Count tool calls
tool_call_counts[tool_name] = (
tool_call_counts.get(tool_name, 0) + 1
)
total_tool_calls += 1
summary_lines.append(f"🔧 Tool Call: '{tool_name}'")
summary_lines.append(f" ID: {tool_id}")
summary_lines.append(f" Arguments: {tool_args}")
summary_lines.append("=" * 50)
elif "Tool" in str(type(message)):
# Tool response
tool_name = getattr(message, "name", "unknown")
tool_call_id = getattr(message, "tool_call_id", "unknown")
content = getattr(message, "content", "no result")
# Truncate long content
truncated_content = truncate_long_content(content, max_lines=30)
summary_lines.append(f"⚙️ Tool Response: '{tool_name}'")
summary_lines.append(f" Call ID: {tool_call_id}")
summary_lines.append(f" Result: {truncated_content}")
summary_lines.append("=" * 50)
# Build tool statistics
tool_stats = {
"tool_call_counts": tool_call_counts,
"total_tool_calls": total_tool_calls,
}
return "\n".join(summary_lines), tool_stats
def print_conversation_summary(response):
"""Print a clean summary of the conversation"""
summary, tool_stats = extract_conversation_summary(response)
print(summary)
print("\n🔧 Tool Usage Statistics:")
print(f" Total tool calls: {tool_stats['total_tool_calls']}")
if tool_stats["tool_call_counts"]:
for tool_name, count in tool_stats["tool_call_counts"].items():
print(f" {tool_name}: {count} calls")


@@ -0,0 +1,21 @@
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
import os
def llm_factory(llm_type: str, llm_model: str):
if llm_type == "openai":
return ChatOpenAI(model=llm_model)
elif llm_type == "ollama":
return ChatOllama(model=llm_model)
elif llm_type == "moonshot":
return ChatOpenAI(
model=llm_model,
base_url="https://api.moonshot.cn/v1",
api_key=os.getenv("MOONSHOT_API_KEY"),
)
elif llm_type == "anthropic":
return ChatAnthropic(model=llm_model, api_key=os.getenv("ANTHROPIC_API_KEY"))
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")
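
A brief usage sketch, assuming the corresponding API key is present in the environment; `gpt-4o-mini` matches the default model described in the evaluation README, and the alternative arguments are merely examples.

# Illustrative sketch: build the default evaluation model and issue a single call.
llm = llm_factory("openai", "gpt-4o-mini")
reply = llm.invoke("In one sentence, what does a ReAct agent do?")
print(reply.content)

# Other providers follow the same pattern, e.g.
#   llm_factory("ollama", "llama3")
#   llm_factory("anthropic", "claude-3-5-sonnet-latest")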

3210
evaluation/uv.lock generated Normal file

File diff suppressed because it is too large