mirror of
https://github.com/zilliztech/claude-context.git
synced 2025-10-06 01:10:02 +03:00
[feat] add dimension detection for embedding providers to support custom models (#129)
Signed-off-by: ShawnZheng <shawn.zheng@zilliz.com>
This commit is contained in:
@@ -243,11 +243,13 @@ export class Context {
|
||||
* Index a codebase for semantic search
|
||||
* @param codebasePath Codebase root path
|
||||
* @param progressCallback Optional progress callback function
|
||||
* @param forceReindex Whether to recreate the collection even if it exists
|
||||
* @returns Indexing statistics
|
||||
*/
|
||||
async indexCodebase(
|
||||
codebasePath: string,
|
||||
progressCallback?: (progress: { phase: string; current: number; total: number; percentage: number }) => void
|
||||
progressCallback?: (progress: { phase: string; current: number; total: number; percentage: number }) => void,
|
||||
forceReindex: boolean = false
|
||||
): Promise<{ indexedFiles: number; totalChunks: number; status: 'completed' | 'limit_reached' }> {
|
||||
const isHybrid = this.getIsHybrid();
|
||||
const searchType = isHybrid === true ? 'hybrid search' : 'semantic search';
|
||||
@@ -258,8 +260,8 @@ export class Context {
|
||||
|
||||
// 2. Check and prepare vector collection
|
||||
progressCallback?.({ phase: 'Preparing collection...', current: 0, total: 100, percentage: 0 });
|
||||
console.log(`Debug2: Preparing vector collection for codebase`);
|
||||
await this.prepareCollection(codebasePath);
|
||||
console.log(`Debug2: Preparing vector collection for codebase${forceReindex ? ' (FORCE REINDEX)' : ''}`);
|
||||
await this.prepareCollection(codebasePath, forceReindex);
|
||||
|
||||
// 3. Recursively traverse codebase to get all supported files
|
||||
progressCallback?.({ phase: 'Scanning files...', current: 5, total: 100, percentage: 5 });
|
||||
@@ -619,25 +621,29 @@ export class Context {
|
||||
/**
|
||||
* Prepare vector collection
|
||||
*/
|
||||
private async prepareCollection(codebasePath: string): Promise<void> {
|
||||
private async prepareCollection(codebasePath: string, forceReindex: boolean = false): Promise<void> {
|
||||
const isHybrid = this.getIsHybrid();
|
||||
const collectionType = isHybrid === true ? 'hybrid vector' : 'vector';
|
||||
console.log(`🔧 Preparing ${collectionType} collection for codebase: ${codebasePath}`);
|
||||
console.log(`🔧 Preparing ${collectionType} collection for codebase: ${codebasePath}${forceReindex ? ' (FORCE REINDEX)' : ''}`);
|
||||
const collectionName = this.getCollectionName(codebasePath);
|
||||
|
||||
// Check if collection already exists
|
||||
const collectionExists = await this.vectorDatabase.hasCollection(collectionName);
|
||||
if (collectionExists) {
|
||||
|
||||
if (collectionExists && !forceReindex) {
|
||||
console.log(`📋 Collection ${collectionName} already exists, skipping creation`);
|
||||
return;
|
||||
}
|
||||
|
||||
// For Ollama embeddings, ensure dimension is detected before creating collection
|
||||
if (this.embedding.getProvider() === 'Ollama' && typeof (this.embedding as any).initializeDimension === 'function') {
|
||||
await (this.embedding as any).initializeDimension();
|
||||
if (collectionExists && forceReindex) {
|
||||
console.log(`🗑️ Dropping existing collection ${collectionName} for force reindex...`);
|
||||
await this.vectorDatabase.dropCollection(collectionName);
|
||||
console.log(`✅ Collection ${collectionName} dropped successfully`);
|
||||
}
|
||||
|
||||
const dimension = this.embedding.getDimension();
|
||||
console.log(`🔍 Detecting embedding dimension for ${this.embedding.getProvider()} provider...`);
|
||||
const dimension = await this.embedding.detectDimension();
|
||||
console.log(`📏 Detected dimension: ${dimension} for ${this.embedding.getProvider()}`);
|
||||
const dirName = path.basename(codebasePath);
|
||||
|
||||
if (isHybrid === true) {
|
||||
|
||||
@@ -31,6 +31,13 @@ export abstract class Embedding {
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect embedding dimension
|
||||
* @param testText Test text for dimension detection
|
||||
* @returns Embedding dimension
|
||||
*/
|
||||
abstract detectDimension(testText?: string): Promise<number>;
|
||||
|
||||
/**
|
||||
* Preprocess array of texts
|
||||
* @param texts Array of input texts
|
||||
|
||||
@@ -43,6 +43,11 @@ export class GeminiEmbedding extends Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Report the embedding dimension for the Gemini provider.
 *
 * Nothing is requested from the API here: the method simply hands back
 * whatever is currently stored in `this.dimension`.
 *
 * @returns The current value of `this.dimension`
 */
async detectDimension(): Promise<number> {
|
||||
// Gemini doesn't need dynamic detection, return configured dimension
|
||||
return this.dimension;
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<EmbeddingVector> {
|
||||
const processedText = this.preprocessText(text);
|
||||
const model = this.config.model || 'gemini-embedding-001';
|
||||
|
||||
@@ -54,46 +54,15 @@ export class OllamaEmbedding extends Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
private async updateDimensionForModel(model: string): Promise<void> {
|
||||
try {
|
||||
// Use a dummy query to detect embedding dimension
|
||||
const embedOptions: any = {
|
||||
model: model,
|
||||
input: 'test',
|
||||
options: this.config.options,
|
||||
};
|
||||
|
||||
// Only include keep_alive if it has a valid value
|
||||
if (this.config.keepAlive && this.config.keepAlive !== '') {
|
||||
embedOptions.keep_alive = this.config.keepAlive;
|
||||
}
|
||||
|
||||
const response = await this.client.embed(embedOptions);
|
||||
|
||||
if (response.embeddings && response.embeddings[0]) {
|
||||
this.dimension = response.embeddings[0].length;
|
||||
this.dimensionDetected = true;
|
||||
console.log(`📏 Detected embedding dimension: ${this.dimension} for model: ${model}`);
|
||||
} else {
|
||||
// Fallback to default dimension
|
||||
this.dimension = 768;
|
||||
this.dimensionDetected = true;
|
||||
console.warn(`⚠️ Could not detect dimension for model ${model}, using default: 768`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(`Failed to detect dimension for model ${model}, using default dimension 768:`, error);
|
||||
this.dimension = 768;
|
||||
this.dimensionDetected = true;
|
||||
}
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<EmbeddingVector> {
|
||||
// Preprocess the text
|
||||
const processedText = this.preprocessText(text);
|
||||
|
||||
// Detect dimension on first use if not configured
|
||||
if (!this.dimensionDetected) {
|
||||
await this.updateDimensionForModel(this.config.model);
|
||||
if (!this.dimensionDetected && !this.config.dimension) {
|
||||
this.dimension = await this.detectDimension();
|
||||
this.dimensionDetected = true;
|
||||
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
|
||||
}
|
||||
|
||||
const embedOptions: any = {
|
||||
@@ -123,9 +92,13 @@ export class OllamaEmbedding extends Embedding {
|
||||
// Preprocess all texts
|
||||
const processedTexts = this.preprocessTexts(texts);
|
||||
|
||||
// Detect dimension on first use if not configured
|
||||
if (!this.dimensionDetected) {
|
||||
await this.updateDimensionForModel(this.config.model);
|
||||
// Detect dimension on first use
|
||||
if (!this.dimensionDetected && !this.config.dimension) {
|
||||
this.dimension = await this.detectDimension();
|
||||
this.dimensionDetected = true;
|
||||
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
|
||||
} else {
|
||||
throw new Error('Failed to detect dimension for model ' + this.config.model);
|
||||
}
|
||||
|
||||
// Use Ollama's native batch embedding API
|
||||
@@ -172,7 +145,11 @@ export class OllamaEmbedding extends Embedding {
|
||||
// Update max tokens for new model
|
||||
this.setDefaultMaxTokensForModel(model);
|
||||
if (!this.config.dimension) {
|
||||
await this.updateDimensionForModel(model);
|
||||
this.dimension = await this.detectDimension();
|
||||
this.dimensionDetected = true;
|
||||
console.log(`📏 Detected Ollama embedding dimension: ${this.dimension} for model: ${this.config.model}`);
|
||||
} else {
|
||||
console.log('Dimension already detected for model ' + this.config.model);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,16 +181,6 @@ export class OllamaEmbedding extends Embedding {
|
||||
this.config.options = options;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set dimension manually
|
||||
* @param dimension Embedding dimension
|
||||
*/
|
||||
setDimension(dimension: number): void {
|
||||
this.config.dimension = dimension;
|
||||
this.dimension = dimension;
|
||||
this.dimensionDetected = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set max tokens manually
|
||||
* @param maxTokens Maximum number of tokens
|
||||
@@ -230,13 +197,34 @@ export class OllamaEmbedding extends Embedding {
|
||||
return this.client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize dimension detection for the current model
|
||||
*/
|
||||
async initializeDimension(): Promise<void> {
|
||||
if (!this.config.dimension) {
|
||||
await this.updateDimensionForModel(this.config.model);
|
||||
/**
 * Detect the embedding dimension of the configured Ollama model by embedding
 * a short probe text and measuring the length of the returned vector.
 *
 * The method only returns the measured length; it does not assign
 * `this.dimension` itself, so callers are responsible for storing the result.
 *
 * @param testText Probe text sent to the model (defaults to "test")
 * @returns Length of the first embedding vector in the response
 * @throws Error when the embed call fails or the response carries no
 *         embedding; the original message is wrapped with a
 *         "Failed to detect Ollama embedding dimension" prefix
 */
async detectDimension(testText: string = "test"): Promise<number> {
|
||||
console.log(`[Ollama] Detecting embedding dimension...`);
|
||||
|
||||
try {
|
||||
// Run the probe text through the same preprocessing hook as regular input
const processedText = this.preprocessText(testText);
|
||||
const embedOptions: any = {
|
||||
model: this.config.model,
|
||||
input: processedText,
|
||||
options: this.config.options,
|
||||
};
|
||||
|
||||
// Only forward keep_alive when a non-empty value is configured
if (this.config.keepAlive && this.config.keepAlive !== '') {
|
||||
embedOptions.keep_alive = this.config.keepAlive;
|
||||
}
|
||||
|
||||
const response = await this.client.embed(embedOptions);
|
||||
|
||||
// A missing `embeddings` field or a falsy first entry is treated as a failure.
// This throw happens inside the try block, so it is caught below and re-wrapped.
if (!response.embeddings || !response.embeddings[0]) {
|
||||
throw new Error('Ollama API returned invalid response');
|
||||
}
|
||||
|
||||
// The dimension is the length of the first returned vector
const dimension = response.embeddings[0].length;
|
||||
console.log(`[Ollama] Successfully detected embedding dimension: ${dimension}`);
|
||||
return dimension;
|
||||
} catch (error) {
|
||||
// Non-Error throwables are reported as 'Unknown error'
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.error(`[Ollama] Failed to detect dimension: ${errorMessage}`);
|
||||
// No fallback dimension is used: detection failure always propagates to the caller
throw new Error(`Failed to detect Ollama embedding dimension: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -20,20 +20,36 @@ export class OpenAIEmbedding extends Embedding {
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseURL,
|
||||
});
|
||||
|
||||
// Set dimension based on model
|
||||
this.updateDimensionForModel(config.model || 'text-embedding-3-small');
|
||||
}
|
||||
|
||||
private updateDimensionForModel(model: string): void {
|
||||
if (model === 'text-embedding-3-small') {
|
||||
this.dimension = 1536;
|
||||
} else if (model === 'text-embedding-3-large') {
|
||||
this.dimension = 3072;
|
||||
} else if (model === 'text-embedding-ada-002') {
|
||||
this.dimension = 1536;
|
||||
} else {
|
||||
this.dimension = 1536; // Default dimension
|
||||
/**
 * Determine the embedding dimension for the configured OpenAI model.
 *
 * Models listed by `OpenAIEmbedding.getSupportedModels()` resolve from that
 * table without any network call. Any other model name is probed by
 * embedding `testText` and measuring the returned vector.
 *
 * @param testText Probe text used only when the model is not in the known-model table (defaults to "test")
 * @returns The table dimension for a known model, otherwise the length of the probed embedding
 * @throws Error prefixed with "Failed to detect dimension for model <model>" when the probe request fails
 */
async detectDimension(testText: string = "test"): Promise<number> {
|
||||
// Falls back to 'text-embedding-3-small' when no model is configured
const model = this.config.model || 'text-embedding-3-small';
|
||||
// Used below as a lookup keyed by model name whose entries expose a `dimension` field
const knownModels = OpenAIEmbedding.getSupportedModels();
|
||||
|
||||
// Use known dimension for standard models
|
||||
if (knownModels[model]) {
|
||||
return knownModels[model].dimension;
|
||||
}
|
||||
|
||||
// For custom models, make API call to detect dimension
|
||||
try {
|
||||
const processedText = this.preprocessText(testText);
|
||||
const response = await this.client.embeddings.create({
|
||||
model: model,
|
||||
input: processedText,
|
||||
encoding_format: 'float',
|
||||
});
|
||||
// If `response.data` were empty, this access would throw inside the try
// block and surface through the catch below as a wrapped error.
return response.data[0].embedding.length;
|
||||
} catch (error) {
|
||||
// Non-Error throwables are reported as 'Unknown error'
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
|
||||
// Re-throw authentication errors
// NOTE(review): this branch throws exactly the same Error as the
// unconditional throw below, so the substring check currently has no
// observable effect — either give auth failures a distinct message/type
// or drop the condition.
|
||||
if (errorMessage.includes('API key') || errorMessage.includes('unauthorized') || errorMessage.includes('authentication')) {
|
||||
throw new Error(`Failed to detect dimension for model ${model}: ${errorMessage}`);
|
||||
}
|
||||
|
||||
// For other errors, throw exception instead of using fallback
|
||||
throw new Error(`Failed to detect dimension for model ${model}: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,32 +57,61 @@ export class OpenAIEmbedding extends Embedding {
|
||||
const processedText = this.preprocessText(text);
|
||||
const model = this.config.model || 'text-embedding-3-small';
|
||||
|
||||
const response = await this.client.embeddings.create({
|
||||
model: model,
|
||||
input: processedText,
|
||||
encoding_format: 'float',
|
||||
});
|
||||
const knownModels = OpenAIEmbedding.getSupportedModels();
|
||||
if (knownModels[model] && this.dimension !== knownModels[model].dimension) {
|
||||
this.dimension = knownModels[model].dimension;
|
||||
} else if (!knownModels[model]) {
|
||||
this.dimension = await this.detectDimension();
|
||||
}
|
||||
|
||||
return {
|
||||
vector: response.data[0].embedding,
|
||||
dimension: this.dimension
|
||||
};
|
||||
try {
|
||||
const response = await this.client.embeddings.create({
|
||||
model: model,
|
||||
input: processedText,
|
||||
encoding_format: 'float',
|
||||
});
|
||||
|
||||
// Update dimension from actual response
|
||||
this.dimension = response.data[0].embedding.length;
|
||||
|
||||
return {
|
||||
vector: response.data[0].embedding,
|
||||
dimension: this.dimension
|
||||
};
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
throw new Error(`Failed to generate OpenAI embedding: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
async embedBatch(texts: string[]): Promise<EmbeddingVector[]> {
|
||||
const processedTexts = this.preprocessTexts(texts);
|
||||
const model = this.config.model || 'text-embedding-3-small';
|
||||
|
||||
const response = await this.client.embeddings.create({
|
||||
model: model,
|
||||
input: processedTexts,
|
||||
encoding_format: 'float',
|
||||
});
|
||||
const knownModels = OpenAIEmbedding.getSupportedModels();
|
||||
if (knownModels[model] && this.dimension !== knownModels[model].dimension) {
|
||||
this.dimension = knownModels[model].dimension;
|
||||
} else if (!knownModels[model]) {
|
||||
this.dimension = await this.detectDimension();
|
||||
}
|
||||
|
||||
return response.data.map((item) => ({
|
||||
vector: item.embedding,
|
||||
dimension: this.dimension
|
||||
}));
|
||||
try {
|
||||
const response = await this.client.embeddings.create({
|
||||
model: model,
|
||||
input: processedTexts,
|
||||
encoding_format: 'float',
|
||||
});
|
||||
|
||||
this.dimension = response.data[0].embedding.length;
|
||||
|
||||
return response.data.map((item) => ({
|
||||
vector: item.embedding,
|
||||
dimension: this.dimension
|
||||
}));
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
|
||||
throw new Error(`Failed to generate OpenAI batch embeddings: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
getDimension(): number {
|
||||
@@ -81,9 +126,14 @@ export class OpenAIEmbedding extends Embedding {
|
||||
* Set model type
|
||||
* @param model Model name
|
||||
*/
|
||||
setModel(model: string): void {
|
||||
async setModel(model: string): Promise<void> {
|
||||
this.config.model = model;
|
||||
this.updateDimensionForModel(model);
|
||||
const knownModels = OpenAIEmbedding.getSupportedModels();
|
||||
if (knownModels[model]) {
|
||||
this.dimension = knownModels[model].dimension;
|
||||
} else {
|
||||
this.dimension = await this.detectDimension();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -61,6 +61,11 @@ export class VoyageAIEmbedding extends Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Report the embedding dimension for the VoyageAI provider.
 *
 * Nothing is requested from the API here: the method simply hands back
 * whatever is currently stored in `this.dimension`.
 *
 * @returns The current value of `this.dimension`
 */
async detectDimension(): Promise<number> {
|
||||
// VoyageAI doesn't need dynamic detection, return configured dimension
|
||||
return this.dimension;
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<EmbeddingVector> {
|
||||
const processedText = this.preprocessText(text);
|
||||
const model = this.config.model || 'voyage-code-3';
|
||||
|
||||
Reference in New Issue
Block a user