[feat] add file extension filter to semantic search (#142)

Signed-off-by: ShawnZheng <shawn.zheng@zilliz.com>
This commit is contained in:
Shawn Zheng
2025-08-09 12:02:03 +08:00
committed by GitHub
parent bfc4809f9f
commit 944e07fb30
10 changed files with 124 additions and 29 deletions

View File

@@ -406,7 +406,7 @@ export class Context {
* @param topK Number of results to return
* @param threshold Similarity threshold
*/
async semanticSearch(codebasePath: string, query: string, topK: number = 5, threshold: number = 0.5): Promise<SemanticSearchResult[]> {
async semanticSearch(codebasePath: string, query: string, topK: number = 5, threshold: number = 0.5, filterExpr?: string): Promise<SemanticSearchResult[]> {
const isHybrid = this.getIsHybrid();
const searchType = isHybrid === true ? 'hybrid search' : 'semantic search';
console.log(`🔍 Executing ${searchType}: "${query}" in ${codebasePath}`);
@@ -465,7 +465,8 @@ export class Context {
strategy: 'rrf',
params: { k: 100 }
},
limit: topK
limit: topK,
filterExpr
}
);
@@ -496,7 +497,7 @@ export class Context {
const searchResults: VectorSearchResult[] = await this.vectorDatabase.search(
collectionName,
queryEmbedding.vector,
{ topK, threshold }
{ topK, threshold, filterExpr }
);
// 3. Convert to semantic search result format

View File

@@ -364,7 +364,7 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
try {
const restfulConfig = this.config as MilvusRestfulConfig;
// Build search request according to Milvus REST API specification
const searchRequest = {
const searchRequest: any = {
collectionName,
dbName: restfulConfig.database,
data: [queryVector], // Array of query vectors
@@ -384,6 +384,11 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
}
};
// Apply boolean expression filter if provided (e.g., fileExtension in ['.ts','.py'])
if (options?.filterExpr && options.filterExpr.trim().length > 0) {
searchRequest.filter = options.filterExpr;
}
const response = await this.makeRequest('/entities/search', 'POST', searchRequest);
// Transform response to VectorSearchResult format
@@ -651,7 +656,7 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
// Prepare search requests according to Milvus REST API hybrid search specification
// For dense vector search - data must be array of vectors: [[0.1, 0.2, 0.3, ...]]
const search_param_1 = {
const search_param_1: any = {
data: Array.isArray(searchRequests[0].data) ? [searchRequests[0].data] : [[searchRequests[0].data]],
annsField: searchRequests[0].anns_field, // "vector"
limit: searchRequests[0].limit,
@@ -663,7 +668,7 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
};
// For sparse vector search - data must be array of queries: ["query text"]
const search_param_2 = {
const search_param_2: any = {
data: Array.isArray(searchRequests[1].data) ? searchRequests[1].data : [searchRequests[1].data],
annsField: searchRequests[1].anns_field, // "sparse_vector"
limit: searchRequests[1].limit,
@@ -674,6 +679,12 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
}
};
// Apply filter to both search parameters if provided
if (options?.filterExpr && options.filterExpr.trim().length > 0) {
search_param_1.filter = options.filterExpr;
search_param_2.filter = options.filterExpr;
}
const rerank_strategy = {
strategy: "rrf",
params: {
@@ -694,7 +705,7 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
searchParams: search_param_2.searchParams
}, null, 2));
const hybridSearchRequest = {
const hybridSearchRequest: any = {
collectionName,
dbName: restfulConfig.database,
search: [search_param_1, search_param_2],
@@ -703,15 +714,6 @@ export class MilvusRestfulVectorDatabase implements VectorDatabase {
outputFields: ['id', 'content', 'relativePath', 'startLine', 'endLine', 'fileExtension', 'metadata'],
};
console.log(`🔍 Complete REST API request:`, JSON.stringify({
collectionName: hybridSearchRequest.collectionName,
dbName: hybridSearchRequest.dbName,
search_count: hybridSearchRequest.search.length,
rerank: hybridSearchRequest.rerank,
limit: hybridSearchRequest.limit,
outputFields: hybridSearchRequest.outputFields
}, null, 2));
console.log(`🔍 Executing REST API hybrid search...`);
const response = await this.makeRequest('/entities/hybrid_search', 'POST', hybridSearchRequest);

View File

@@ -234,13 +234,18 @@ export class MilvusVectorDatabase implements VectorDatabase {
async search(collectionName: string, queryVector: number[], options?: SearchOptions): Promise<VectorSearchResult[]> {
await this.ensureInitialized();
const searchParams = {
const searchParams: any = {
collection_name: collectionName,
data: [queryVector],
limit: options?.topK || 10,
output_fields: ['id', 'content', 'relativePath', 'startLine', 'endLine', 'fileExtension', 'metadata'],
};
// Apply boolean expression filter if provided (e.g., fileExtension in [".ts",".py"])
if (options?.filterExpr && options.filterExpr.trim().length > 0) {
searchParams.expr = options.filterExpr;
}
const searchResult = await this.client!.search(searchParams);
if (!searchResult.results || searchResult.results.length === 0) {
@@ -480,7 +485,7 @@ export class MilvusVectorDatabase implements VectorDatabase {
console.log(`🔍 Rerank strategy:`, JSON.stringify(rerank_strategy, null, 2));
// Execute hybrid search using the correct client.search format
const searchParams = {
const searchParams: any = {
collection_name: collectionName,
data: [search_param_1, search_param_2],
limit: options?.limit || searchRequests[0]?.limit || 10,
@@ -488,12 +493,17 @@ export class MilvusVectorDatabase implements VectorDatabase {
output_fields: ['id', 'content', 'relativePath', 'startLine', 'endLine', 'fileExtension', 'metadata'],
};
if (options?.filterExpr && options.filterExpr.trim().length > 0) {
searchParams.expr = options.filterExpr;
}
console.log(`🔍 Complete search request:`, JSON.stringify({
collection_name: searchParams.collection_name,
data_count: searchParams.data.length,
limit: searchParams.limit,
rerank: searchParams.rerank,
output_fields: searchParams.output_fields
output_fields: searchParams.output_fields,
expr: searchParams.expr
}, null, 2));
const searchResult = await this.client!.search(searchParams);

View File

@@ -14,6 +14,7 @@ export interface SearchOptions {
topK?: number;
filter?: Record<string, any>;
threshold?: number;
filterExpr?: string;
}
// New interfaces for hybrid search
@@ -27,6 +28,7 @@ export interface HybridSearchRequest {
/**
 * Optional settings accepted by hybrid search calls on a vector database.
 * Every field is optional; omitted fields fall back to whatever the
 * concrete implementation chooses.
 */
export interface HybridSearchOptions {
/** Rerank configuration; presumably controls how results of the individual search requests are merged — confirm against RerankStrategy. */
rerank?: RerankStrategy;
/** Maximum number of results to return. The Milvus implementation in this change falls back to the first request's limit, then 10, when this is unset. */
limit?: number;
/**
 * Boolean filter expression passed through to the backend, e.g.
 * `fileExtension in ['.ts', '.py']`. The implementations in this change
 * ignore it when it is empty or whitespace-only. The string is forwarded
 * as-is, so callers are responsible for building a well-formed expression.
 */
filterExpr?: string;
}
export interface RerankStrategy {

View File

@@ -410,7 +410,7 @@ export class ToolHandlers {
}
public async handleSearchCode(args: any) {
const { path: codebasePath, query, limit = 10 } = args;
const { path: codebasePath, query, limit = 10, extensionFilter } = args;
const resultLimit = limit || 10;
try {
@@ -474,12 +474,31 @@ export class ToolHandlers {
console.log(`[SEARCH] 🧠 Using embedding provider: ${embeddingProvider.getProvider()} for search`);
console.log(`[SEARCH] 🔍 Generating embeddings for query using ${embeddingProvider.getProvider()}...`);
// Build filter expression from extensionFilter list
let filterExpr: string | undefined = undefined;
if (Array.isArray(extensionFilter) && extensionFilter.length > 0) {
const cleaned = extensionFilter
.filter((v: any) => typeof v === 'string')
.map((v: string) => v.trim())
.filter((v: string) => v.length > 0);
const invalid = cleaned.filter((e: string) => !(e.startsWith('.') && e.length > 1 && !/\s/.test(e)));
if (invalid.length > 0) {
return {
content: [{ type: 'text', text: `Error: Invalid file extensions in extensionFilter: ${JSON.stringify(invalid)}. Use proper extensions like '.ts', '.py'.` }],
isError: true
};
}
const quoted = cleaned.map((e: string) => `'${e}'`).join(', ');
filterExpr = `fileExtension in [${quoted}]`;
}
// Search in the specified codebase
const searchResults = await this.context.semanticSearch(
absolutePath,
query,
Math.min(resultLimit, 50),
0.3
0.3,
filterExpr
);
console.log(`[SEARCH] ✅ Search completed! Found ${searchResults.length} results using ${embeddingProvider.getProvider()} embeddings`);

View File

@@ -180,6 +180,14 @@ This tool is versatile and can be used before completing various tasks to retrie
description: "Maximum number of results to return",
default: 10,
maximum: 50
},
extensionFilter: {
type: "array",
items: {
type: "string"
},
description: "Optional: List of file extensions to filter results. (e.g., ['.ts','.py']).",
default: []
}
},
required: ["path", "query"]

View File

@@ -61,6 +61,31 @@ export class SearchCommand {
return;
}
// Optionally prompt for file extension filters
const extensionInput = await vscode.window.showInputBox({
placeHolder: 'Optional: filter by file extensions (e.g. .ts,.py,.java) leave empty for all',
prompt: 'Enter a comma-separated list of file extensions to include',
value: ''
});
const fileExtensions = (extensionInput || '')
.split(',')
.map(e => e.trim())
.filter(Boolean);
// Validate extensions strictly and build filter expression
let filterExpr: string | undefined = undefined;
if (fileExtensions.length > 0) {
const invalid = fileExtensions.filter(e => !(e.startsWith('.') && e.length > 1 && !/\s/.test(e)));
if (invalid.length > 0) {
vscode.window.showErrorMessage(`Invalid extensions: ${invalid.join(', ')}. Use proper extensions like '.ts', '.py'.`);
return;
}
const quoted = fileExtensions.map(e => `'${e}'`).join(',');
filterExpr = `fileExtension in [${quoted}]`;
}
// Use semantic search
const query: SearchQuery = {
term: searchTerm,
@@ -71,12 +96,14 @@ export class SearchCommand {
console.log('🔍 Using semantic search...');
progress.report({ increment: 50, message: 'Executing semantic search...' });
const results = await this.context.semanticSearch(
let results = await this.context.semanticSearch(
codebasePath,
query.term,
query.limit || 20,
0.3 // similarity threshold
0.3, // similarity threshold
filterExpr
);
// No client-side filtering; filter pushed down via filter expression
progress.report({ increment: 100, message: 'Search complete!' });
@@ -139,7 +166,7 @@ export class SearchCommand {
/**
* Execute search for webview (without UI prompts)
*/
async executeForWebview(searchTerm: string, limit: number = 50): Promise<SemanticSearchResult[]> {
async executeForWebview(searchTerm: string, limit: number = 50, fileExtensions: string[] = []): Promise<SemanticSearchResult[]> {
// Get workspace root for codebase path
const workspaceFolders = vscode.workspace.workspaceFolders;
if (!workspaceFolders || workspaceFolders.length === 0) {
@@ -154,12 +181,26 @@ export class SearchCommand {
}
console.log('🔍 Using semantic search for webview...');
return await this.context.semanticSearch(
// Validate extensions strictly and build filter expression
let filterExpr: string | undefined = undefined;
if (fileExtensions && fileExtensions.length > 0) {
const invalid = fileExtensions.filter(e => !(typeof e === 'string' && e.startsWith('.') && e.length > 1 && !/\s/.test(e)));
if (invalid.length > 0) {
throw new Error(`Invalid extensions: ${invalid.join(', ')}. Use proper extensions like '.ts', '.py'.`);
}
const quoted = fileExtensions.map(e => `'${e}'`).join(',');
filterExpr = `fileExtension in [${quoted}]`;
}
let results = await this.context.semanticSearch(
codebasePath,
searchTerm,
limit,
0.3 // similarity threshold
0.3, // similarity threshold
filterExpr
);
return results;
}
/**

View File

@@ -22,6 +22,7 @@ class SemanticSearchController {
initializeElements() {
// Search view elements
this.searchInput = document.getElementById('searchInput');
this.extFilterInput = document.getElementById('extFilterInput');
this.searchButton = document.getElementById('searchButton');
this.indexButton = document.getElementById('indexButton');
this.settingsButton = document.getElementById('settingsButton');
@@ -90,10 +91,15 @@ class SemanticSearchController {
*/
performSearch() {
const text = this.searchInput.value.trim();
const extFilterRaw = (this.extFilterInput?.value || '').trim();
const extensions = extFilterRaw
? extFilterRaw.split(',').map(e => e.trim()).filter(Boolean)
: [];
if (text && !this.searchButton.disabled) {
this.vscode.postMessage({
command: 'search',
text: text
text: text,
fileExtensions: extensions
});
}
}

View File

@@ -79,7 +79,8 @@ export class SemanticSearchViewProvider implements vscode.WebviewViewProvider {
// Use search command
const searchResults = await this.searchCommand.executeForWebview(
message.text,
50
50,
Array.isArray(message.fileExtensions) ? message.fileExtensions : []
);
// Convert SemanticSearchResult[] to webview format

View File

@@ -29,8 +29,13 @@
</svg>
</button>
</div>
<label for="searchInput" class="input-label">Semantic search query</label>
<input type="text" id="searchInput" class="search-input"
placeholder="Enter your semantic search query..." />
placeholder="e.g. 'The main function of the codebase'" />
<label for="extFilterInput" class="input-label" style="margin-top: 10px; display: block;">Extension filter
(optional)</label>
<input type="text" id="extFilterInput" class="search-input" style="margin-top: 6px;"
placeholder="e.g. .ts,.py,.java" />
<button id="searchButton" class="search-button">Search</button>
<button id="indexButton" class="search-button"
style="margin-top: 4px; background-color: var(--vscode-button-secondaryBackground); color: var(--vscode-button-secondaryForeground);">