[feat] add dimension detection for embedding providers to support custom models (#129)

Signed-off-by: ShawnZheng <shawn.zheng@zilliz.com>
This commit is contained in:
Shawn Zheng
2025-08-07 16:54:06 +08:00
committed by GitHub
parent 31aecd5f5f
commit 26f20dc2c3
6 changed files with 159 additions and 98 deletions

View File

@@ -243,11 +243,13 @@ export class Context {
* Index a codebase for semantic search
* @param codebasePath Codebase root path
* @param progressCallback Optional progress callback function
* @param forceReindex Whether to recreate the collection even if it exists
* @returns Indexing statistics
*/
async indexCodebase(
codebasePath: string,
progressCallback?: (progress: { phase: string; current: number; total: number; percentage: number }) => void
progressCallback?: (progress: { phase: string; current: number; total: number; percentage: number }) => void,
forceReindex: boolean = false
): Promise<{ indexedFiles: number; totalChunks: number; status: 'completed' | 'limit_reached' }> {
const isHybrid = this.getIsHybrid();
const searchType = isHybrid === true ? 'hybrid search' : 'semantic search';
@@ -258,8 +260,8 @@ export class Context {
// 2. Check and prepare vector collection
progressCallback?.({ phase: 'Preparing collection...', current: 0, total: 100, percentage: 0 });
console.log(`Debug2: Preparing vector collection for codebase`);
await this.prepareCollection(codebasePath);
console.log(`Debug2: Preparing vector collection for codebase${forceReindex ? ' (FORCE REINDEX)' : ''}`);
await this.prepareCollection(codebasePath, forceReindex);
// 3. Recursively traverse codebase to get all supported files
progressCallback?.({ phase: 'Scanning files...', current: 5, total: 100, percentage: 5 });
@@ -619,25 +621,29 @@ export class Context {
/**
* Prepare vector collection
*/
private async prepareCollection(codebasePath: string): Promise<void> {
private async prepareCollection(codebasePath: string, forceReindex: boolean = false): Promise<void> {
const isHybrid = this.getIsHybrid();
const collectionType = isHybrid === true ? 'hybrid vector' : 'vector';
console.log(`🔧 Preparing ${collectionType} collection for codebase: ${codebasePath}`);
console.log(`🔧 Preparing ${collectionType} collection for codebase: ${codebasePath}${forceReindex ? ' (FORCE REINDEX)' : ''}`);
const collectionName = this.getCollectionName(codebasePath);
// Check if collection already exists
const collectionExists = await this.vectorDatabase.hasCollection(collectionName);
if (collectionExists) {
if (collectionExists && !forceReindex) {
console.log(`📋 Collection ${collectionName} already exists, skipping creation`);
return;
}
// For Ollama embeddings, ensure dimension is detected before creating collection
if (this.embedding.getProvider() === 'Ollama' && typeof (this.embedding as any).initializeDimension === 'function') {
await (this.embedding as any).initializeDimension();
if (collectionExists && forceReindex) {
console.log(`🗑️ Dropping existing collection ${collectionName} for force reindex...`);
await this.vectorDatabase.dropCollection(collectionName);
console.log(`✅ Collection ${collectionName} dropped successfully`);
}
const dimension = this.embedding.getDimension();
console.log(`🔍 Detecting embedding dimension for ${this.embedding.getProvider()} provider...`);
const dimension = await this.embedding.detectDimension();
console.log(`📏 Detected dimension: ${dimension} for ${this.embedding.getProvider()}`);
const dirName = path.basename(codebasePath);
if (isHybrid === true) {

View File

@@ -31,6 +31,13 @@ export abstract class Embedding {
return text;
}
/**
 * Detect the embedding dimension produced by the current model.
 * Providers whose dimension is fixed by configuration may simply return the
 * stored value; others probe their API with a short test input and measure
 * the returned vector length.
 * @param testText Optional probe text used when a live API call is needed
 * @returns Embedding dimension (length of the embedding vector)
 */
abstract detectDimension(testText?: string): Promise<number>;
/**
* Preprocess array of texts
* @param texts Array of input texts

View File

@@ -43,6 +43,11 @@ export class GeminiEmbedding extends Embedding {
}
}
async detectDimension(): Promise<number> {
    // The dimension for Gemini models is fixed by configuration, so no
    // probe request is necessary — just report the stored value.
    const configuredDimension = this.dimension;
    return configuredDimension;
}
async embed(text: string): Promise<EmbeddingVector> {
const processedText = this.preprocessText(text);
const model = this.config.model || 'gemini-embedding-001';

View File

@@ -54,46 +54,15 @@ export class OllamaEmbedding extends Embedding {
}
}
private async updateDimensionForModel(model: string): Promise<void> {
try {
// Use a dummy query to detect embedding dimension
const embedOptions: any = {
model: model,
input: 'test',
options: this.config.options,
};
// Only include keep_alive if it has a valid value
if (this.config.keepAlive && this.config.keepAlive !== '') {
embedOptions.keep_alive = this.config.keepAlive;
}
const response = await this.client.embed(embedOptions);
if (response.embeddings && response.embeddings[0]) {
this.dimension = response.embeddings[0].length;
this.dimensionDetected = true;
console.log(`📏 Detected embedding dimension: ${this.dimension} for model: ${model}`);
} else {
// Fallback to default dimension
this.dimension = 768;
this.dimensionDetected = true;
console.warn(`⚠️ Could not detect dimension for model ${model}, using default: 768`);
}
} catch (error) {
console.warn(`Failed to detect dimension for model ${model}, using default dimension 768:`, error);
this.dimension = 768;
this.dimensionDetected = true;
}
}
async embed(text: string): Promise<EmbeddingVector> {
// Preprocess the text
const processedText = this.preprocessText(text);
// Detect dimension on first use if not configured
if (!this.dimensionDetected) {
await this.updateDimensionForModel(this.config.model);
if (!this.dimensionDetected && !this.config.dimension) {
this.dimension = await this.detectDimension();
this.dimensionDetected = true;
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
}
const embedOptions: any = {
@@ -123,9 +92,13 @@ export class OllamaEmbedding extends Embedding {
// Preprocess all texts
const processedTexts = this.preprocessTexts(texts);
// Detect dimension on first use if not configured
if (!this.dimensionDetected) {
await this.updateDimensionForModel(this.config.model);
// Detect dimension on first use
if (!this.dimensionDetected && !this.config.dimension) {
this.dimension = await this.detectDimension();
this.dimensionDetected = true;
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
} else {
throw new Error('Failed to detect dimension for model ' + this.config.model);
}
// Use Ollama's native batch embedding API
@@ -172,7 +145,11 @@ export class OllamaEmbedding extends Embedding {
// Update max tokens for new model
this.setDefaultMaxTokensForModel(model);
if (!this.config.dimension) {
await this.updateDimensionForModel(model);
this.dimension = await this.detectDimension();
this.dimensionDetected = true;
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
} else {
console.log('Dimension already detected for model ' + this.config.model);
}
}
@@ -204,16 +181,6 @@ export class OllamaEmbedding extends Embedding {
this.config.options = options;
}
/**
* Set dimension manually
* @param dimension Embedding dimension
*/
setDimension(dimension: number): void {
this.config.dimension = dimension;
this.dimension = dimension;
this.dimensionDetected = true;
}
/**
* Set max tokens manually
* @param maxTokens Maximum number of tokens
@@ -230,13 +197,34 @@ export class OllamaEmbedding extends Embedding {
return this.client;
}
/**
 * Detect the embedding dimension for the configured Ollama model.
 *
 * Sends a minimal probe request to the Ollama embed endpoint and measures
 * the length of the returned vector. Fails loudly instead of guessing a
 * default, so a wrong dimension never reaches the vector collection.
 *
 * @param testText Probe text for the detection request (defaults to "test")
 * @returns The detected embedding dimension
 * @throws Error when the API call fails or returns no embeddings
 */
async detectDimension(testText: string = "test"): Promise<number> {
    console.log(`[Ollama] Detecting embedding dimension...`);
    try {
        const processedText = this.preprocessText(testText);
        const embedOptions: any = {
            model: this.config.model,
            input: processedText,
            options: this.config.options,
        };
        // Only forward keep_alive when a non-empty value was configured
        if (this.config.keepAlive && this.config.keepAlive !== '') {
            embedOptions.keep_alive = this.config.keepAlive;
        }
        const response = await this.client.embed(embedOptions);
        if (!response.embeddings || !response.embeddings[0]) {
            throw new Error('Ollama API returned invalid response');
        }
        const dimension = response.embeddings[0].length;
        console.log(`[Ollama] Successfully detected embedding dimension: ${dimension}`);
        return dimension;
    } catch (error) {
        const errorMessage = error instanceof Error ? error.message : 'Unknown error';
        console.error(`[Ollama] Failed to detect dimension: ${errorMessage}`);
        throw new Error(`Failed to detect Ollama embedding dimension: ${errorMessage}`);
    }
}
}

View File

@@ -20,20 +20,36 @@ export class OpenAIEmbedding extends Embedding {
apiKey: config.apiKey,
baseURL: config.baseURL,
});
// Set dimension based on model
this.updateDimensionForModel(config.model || 'text-embedding-3-small');
}
private updateDimensionForModel(model: string): void {
if (model === 'text-embedding-3-small') {
this.dimension = 1536;
} else if (model === 'text-embedding-3-large') {
this.dimension = 3072;
} else if (model === 'text-embedding-ada-002') {
this.dimension = 1536;
} else {
this.dimension = 1536; // Default dimension
/**
 * Detect the embedding dimension for the configured OpenAI model.
 *
 * Standard models resolve instantly from the static model table; custom
 * models are probed with one embeddings API call and the vector length of
 * the response is measured.
 *
 * @param testText Probe text for the detection request (defaults to "test")
 * @returns The embedding dimension for the active model
 * @throws Error when the probe API call fails for any reason
 */
async detectDimension(testText: string = "test"): Promise<number> {
    const model = this.config.model || 'text-embedding-3-small';
    const knownModels = OpenAIEmbedding.getSupportedModels();
    // Known models have documented dimensions — no API round-trip needed.
    if (knownModels[model]) {
        return knownModels[model].dimension;
    }
    // For custom models, make API call to detect dimension
    try {
        const processedText = this.preprocessText(testText);
        const response = await this.client.embeddings.create({
            model: model,
            input: processedText,
            encoding_format: 'float',
        });
        return response.data[0].embedding.length;
    } catch (error) {
        const errorMessage = error instanceof Error ? error.message : 'Unknown error';
        // Every failure (auth errors included) is propagated: silently
        // falling back to a guessed dimension would corrupt the collection
        // schema. (The previous if/else here threw the identical error in
        // both branches, so the auth-error check was dead code.)
        throw new Error(`Failed to detect dimension for model ${model}: ${errorMessage}`);
    }
}
@@ -41,32 +57,61 @@ export class OpenAIEmbedding extends Embedding {
const processedText = this.preprocessText(text);
const model = this.config.model || 'text-embedding-3-small';
const response = await this.client.embeddings.create({
model: model,
input: processedText,
encoding_format: 'float',
});
const knownModels = OpenAIEmbedding.getSupportedModels();
if (knownModels[model] && this.dimension !== knownModels[model].dimension) {
this.dimension = knownModels[model].dimension;
} else if (!knownModels[model]) {
this.dimension = await this.detectDimension();
}
return {
vector: response.data[0].embedding,
dimension: this.dimension
};
try {
const response = await this.client.embeddings.create({
model: model,
input: processedText,
encoding_format: 'float',
});
// Update dimension from actual response
this.dimension = response.data[0].embedding.length;
return {
vector: response.data[0].embedding,
dimension: this.dimension
};
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
throw new Error(`Failed to generate OpenAI embedding: ${errorMessage}`);
}
}
/**
 * Embed a batch of texts with one OpenAI embeddings API call.
 *
 * The cached dimension is refreshed from the actual response vector length,
 * so it always matches what the collection will store.
 *
 * @param texts Input texts to embed
 * @returns One EmbeddingVector per input, in order
 * @throws Error when the OpenAI API call fails
 *
 * NOTE(review): this span previously contained both the pre- and post-change
 * bodies concatenated by the diff rendering (two `response` declarations,
 * two returns); only the final version is kept here.
 */
async embedBatch(texts: string[]): Promise<EmbeddingVector[]> {
    const processedTexts = this.preprocessTexts(texts);
    const model = this.config.model || 'text-embedding-3-small';
    try {
        const response = await this.client.embeddings.create({
            model: model,
            input: processedTexts,
            encoding_format: 'float',
        });
        // Update dimension from the actual response rather than trusting config
        this.dimension = response.data[0].embedding.length;
        return response.data.map((item) => ({
            vector: item.embedding,
            dimension: this.dimension
        }));
    } catch (error) {
        const errorMessage = error instanceof Error ? error.message : 'Unknown error';
        throw new Error(`Failed to generate OpenAI batch embeddings: ${errorMessage}`);
    }
}
getDimension(): number {
@@ -81,9 +126,14 @@ export class OpenAIEmbedding extends Embedding {
* Set model type
* @param model Model name
*/
setModel(model: string): void {
/**
 * Switch the active embedding model and refresh the cached dimension.
 * Known models resolve their dimension from the static model table; any
 * other model triggers a live detection call.
 * @param model Model name
 */
async setModel(model: string): Promise<void> {
    this.config.model = model;
    const catalog = OpenAIEmbedding.getSupportedModels();
    const entry = catalog[model];
    this.dimension = entry
        ? entry.dimension
        : await this.detectDimension();
}
/**

View File

@@ -61,6 +61,11 @@ export class VoyageAIEmbedding extends Embedding {
}
}
async detectDimension(): Promise<number> {
    // VoyageAI dimensions are fixed per configured model, so no probe call
    // is made — the stored value is handed back directly.
    const fixedDimension = this.dimension;
    return fixedDimension;
}
async embed(text: string): Promise<EmbeddingVector> {
const processedText = this.preprocessText(text);
const model = this.config.model || 'voyage-code-3';