mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
fix: use anthropic provider for vertexAI (#398)
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/bedrock"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/anthropics/anthropic-sdk-go/vertex"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
@@ -26,21 +27,30 @@ var contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`"
|
||||
|
||||
type anthropicClient struct {
|
||||
providerOptions providerClientOptions
|
||||
useBedrock bool
|
||||
tp AnthropicClientType
|
||||
client anthropic.Client
|
||||
adjustedMaxTokens int // Used when context limit is hit
|
||||
}
|
||||
|
||||
type AnthropicClient ProviderClient
|
||||
|
||||
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
|
||||
type AnthropicClientType string
|
||||
|
||||
const (
|
||||
AnthropicClientTypeNormal AnthropicClientType = "normal"
|
||||
AnthropicClientTypeBedrock AnthropicClientType = "bedrock"
|
||||
AnthropicClientTypeVertex AnthropicClientType = "vertex"
|
||||
)
|
||||
|
||||
func newAnthropicClient(opts providerClientOptions, tp AnthropicClientType) AnthropicClient {
|
||||
return &anthropicClient{
|
||||
providerOptions: opts,
|
||||
client: createAnthropicClient(opts, useBedrock),
|
||||
tp: tp,
|
||||
client: createAnthropicClient(opts, tp),
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
|
||||
func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) anthropic.Client {
|
||||
anthropicClientOptions := []option.RequestOption{}
|
||||
|
||||
// Check if Authorization header is provided in extra headers
|
||||
@@ -67,8 +77,13 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi
|
||||
} else if hasBearerAuth {
|
||||
slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
|
||||
}
|
||||
if useBedrock {
|
||||
switch tp {
|
||||
case AnthropicClientTypeBedrock:
|
||||
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
|
||||
case AnthropicClientTypeVertex:
|
||||
project := opts.extraParams["project"]
|
||||
location := opts.extraParams["location"]
|
||||
anthropicClientOptions = append(anthropicClientOptions, vertex.WithGoogleAuth(context.Background(), location, project))
|
||||
}
|
||||
for key, header := range opts.extraHeaders {
|
||||
anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(key, header))
|
||||
@@ -478,7 +493,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
|
||||
}
|
||||
a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
|
||||
a.client = createAnthropicClient(a.providerOptions, a.tp)
|
||||
return true, 0, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
opts.disableCache = true // Disable cache for Bedrock
|
||||
return &bedrockClient{
|
||||
providerOptions: opts,
|
||||
childProvider: newAnthropicClient(anthropicOpts, true),
|
||||
childProvider: newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
case catwalk.TypeAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions, false),
|
||||
client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
|
||||
}, nil
|
||||
case catwalk.TypeOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
|
||||
@@ -3,6 +3,7 @@ package provider
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
@@ -22,6 +23,10 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
|
||||
return nil
|
||||
}
|
||||
|
||||
model := opts.model(opts.modelType)
|
||||
if strings.Contains(model.ID, "anthropic") || strings.Contains(model.ID, "claude-sonnet") {
|
||||
return newAnthropicClient(opts, AnthropicClientTypeVertex)
|
||||
}
|
||||
return &geminiClient{
|
||||
providerOptions: opts,
|
||||
client: client,
|
||||
|
||||
Reference in New Issue
Block a user