fix: use anthropic provider for vertexAI (#398)

This commit is contained in:
Kujtim Hoxha
2025-07-31 18:40:54 +02:00
committed by GitHub
parent 713a7f72a0
commit 8248e4f649
6 changed files with 43 additions and 8 deletions

5
go.mod
View File

@@ -49,11 +49,16 @@ require (
)
require (
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/time v0.8.0 // indirect
google.golang.org/api v0.211.0 // indirect
)
require (

10
go.sum
View File

@@ -2,6 +2,8 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ=
@@ -290,6 +292,8 @@ github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
@@ -328,6 +332,8 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -367,11 +373,15 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.211.0 h1:IUpLjq09jxBSV1lACO33CGY3jsRcbctfGzhj+ZSE/Bg=
google.golang.org/api v0.211.0/go.mod h1:XOloB4MXFH4UTlQSGuNUxw0UT74qdENK8d6JNsXKLi0=
google.golang.org/genai v1.3.0 h1:tXhPJF30skOjnnDY7ZnjK3q7IKy4PuAlEA0fk7uEaEI=
google.golang.org/genai v1.3.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g=

View File

@@ -15,6 +15,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/vertex"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
@@ -26,21 +27,30 @@ var contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`"
type anthropicClient struct {
providerOptions providerClientOptions
useBedrock bool
tp AnthropicClientType
client anthropic.Client
adjustedMaxTokens int // Used when context limit is hit
}
type AnthropicClient ProviderClient
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
type AnthropicClientType string
const (
AnthropicClientTypeNormal AnthropicClientType = "normal"
AnthropicClientTypeBedrock AnthropicClientType = "bedrock"
AnthropicClientTypeVertex AnthropicClientType = "vertex"
)
func newAnthropicClient(opts providerClientOptions, tp AnthropicClientType) AnthropicClient {
return &anthropicClient{
providerOptions: opts,
client: createAnthropicClient(opts, useBedrock),
tp: tp,
client: createAnthropicClient(opts, tp),
}
}
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
// Check if Authorization header is provided in extra headers
@@ -67,8 +77,13 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi
} else if hasBearerAuth {
slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
}
if useBedrock {
switch tp {
case AnthropicClientTypeBedrock:
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
case AnthropicClientTypeVertex:
project := opts.extraParams["project"]
location := opts.extraParams["location"]
anthropicClientOptions = append(anthropicClientOptions, vertex.WithGoogleAuth(context.Background(), location, project))
}
for key, header := range opts.extraHeaders {
anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(key, header))
@@ -478,7 +493,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
a.client = createAnthropicClient(a.providerOptions, a.tp)
return true, 0, nil
}

View File

@@ -52,7 +52,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
opts.disableCache = true // Disable cache for Bedrock
return &bedrockClient{
providerOptions: opts,
childProvider: newAnthropicClient(anthropicOpts, true),
childProvider: newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock),
}
}

View File

@@ -173,7 +173,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
case catwalk.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
client: newAnthropicClient(clientOptions, false),
client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
}, nil
case catwalk.TypeOpenAI:
return &baseProvider[OpenAIClient]{

View File

@@ -3,6 +3,7 @@ package provider
import (
"context"
"log/slog"
"strings"
"google.golang.org/genai"
)
@@ -22,6 +23,10 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
return nil
}
model := opts.model(opts.modelType)
if strings.Contains(model.ID, "anthropic") || strings.Contains(model.ID, "claude-sonnet") {
return newAnthropicClient(opts, AnthropicClientTypeVertex)
}
return &geminiClient{
providerOptions: opts,
client: client,