Taciturnaxolotl/custom anthropic providers (#300)

* feat: support anthropic provider type in custom provider configs
* docs: fix provider configuration field name and add anthropic example
- Change `provider_type` to `type` in documentation to match actual struct field
- Add comprehensive examples for both OpenAI and Anthropic custom providers
- Include missing `api_key` field in examples for completeness
* feat: resolve headers to allow for custom scripts and such in headers
* feat: allow headers in the anthropic client
* feat: if api_key has "Bearer " in front then using it as an
Authorization header and skip the X-API-Key header in the anthropic
client
* feat: add support for templating in the config resolve.go
something like `Bearer $(echo $ENVVAR)-$(bash ~/.config/crush/script.sh)` would work now; also added some tests since the first iteration of this broke stuff majorly lol
* feat: add a system prompt prefix option to the config
---------
Co-authored-by: Kieran Klukas <me@dunkirk.sh>
Co-authored-by: Kieran Klukas <l41cge3m@duck.com>
This commit is contained in:
Kujtim Hoxha
2025-07-25 11:52:00 +02:00
committed by GitHub
parent adcf0e1b51
commit 8c874293c9
10 changed files with 427 additions and 59 deletions

View File

@@ -135,7 +135,6 @@ Crush supports Model Context Protocol (MCP) servers through three transport type
### Logging
Enable debug logging with the `-d` flag or in config. View logs with `crush logs`. Logs are stored in `.crush/logs/crush.log`.
```bash
# Run with debug logging
crush -d
@@ -186,16 +185,21 @@ The `allowed_tools` array accepts:
You can also skip all permission prompts entirely by running Crush with the `--yolo` flag.
### OpenAI-Compatible APIs
### Custom Providers
Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
Crush supports custom provider configurations for both OpenAI-compatible and Anthropic-compatible APIs.
#### OpenAI-Compatible APIs
Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
```json
{
"providers": {
"deepseek": {
"provider_type": "openai",
"type": "openai",
"base_url": "https://api.deepseek.com/v1",
"api_key": "$DEEPSEEK_API_KEY",
"models": [
{
"id": "deepseek-chat",
@@ -213,6 +217,38 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D
}
```
#### Anthropic-Compatible APIs
You can also configure custom Anthropic-compatible providers:
```json
{
"providers": {
"custom-anthropic": {
"type": "anthropic",
"base_url": "https://api.anthropic.com/v1",
"api_key": "$ANTHROPIC_API_KEY",
"extra_headers": {
"anthropic-version": "2023-06-01"
},
"models": [
{
"id": "claude-3-sonnet",
"model": "Claude 3 Sonnet",
"cost_per_1m_in": 3000,
"cost_per_1m_out": 15000,
"cost_per_1m_in_cached": 300,
"cost_per_1m_out_cached": 15000,
"context_window": 200000,
"default_max_tokens": 4096,
"supports_attachments": true
}
]
}
}
}
```
## Whatcha think?
Wed love to hear your thoughts on this project. Feel free to drop us a note!

View File

@@ -77,6 +77,9 @@ type ProviderConfig struct {
// Marks the provider as disabled.
Disable bool `json:"disable,omitempty"`
// Custom system prompt prefix.
SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
// Extra headers to send with each request to the provider.
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// Extra body

View File

@@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
c.Providers.Del(id)
continue
}
if providerConfig.Type != catwalk.TypeOpenAI {
if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
c.Providers.Del(id)
continue

View File

@@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
})
t.Run("custom anthropic provider is supported", func(t *testing.T) {
cfg := &Config{
Providers: map[string]ProviderConfig{
"custom-anthropic": {
APIKey: "test-key",
BaseURL: "https://api.anthropic.com/v1",
Type: catwalk.TypeAnthropic,
Models: []catwalk.Model{{
ID: "claude-3-sonnet",
}},
},
},
}
cfg.setDefaults("/tmp")
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 1)
customProvider, exists := cfg.Providers["custom-anthropic"]
assert.True(t, exists)
assert.Equal(t, "custom-anthropic", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey)
assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
})
t.Run("disabled custom provider is removed", func(t *testing.T) {
cfg := &Config{
Providers: csync.NewMapFrom(map[string]ProviderConfig{

View File

@@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver {
}
// ResolveValue is a method for resolving values, such as environment variables.
// it will expect strings that start with `$` to be resolved as environment variables or shell commands.
// if the string does not start with `$`, it will return the string as is.
// it will resolve shell-like variable substitution anywhere in the string, including:
// - $(command) for command substitution
// - $VAR or ${VAR} for environment variables
func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
if !strings.HasPrefix(value, "$") {
// Special case: lone $ is an error (backward compatibility)
if value == "$" {
return "", fmt.Errorf("invalid value format: %s", value)
}
// If no $ found, return as-is
if !strings.Contains(value, "$") {
return value, nil
}
if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
result := value
// Handle command substitution: $(command)
for {
start := strings.Index(result, "$(")
if start == -1 {
break
}
// Find matching closing parenthesis
depth := 0
end := -1
for i := start + 2; i < len(result); i++ {
if result[i] == '(' {
depth++
} else if result[i] == ')' {
if depth == 0 {
end = i
break
}
depth--
}
}
if end == -1 {
return "", fmt.Errorf("unmatched $( in value: %s", value)
}
command := result[start+2 : end]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
stdout, _, err := r.shell.Exec(ctx, command)
cancel()
if err != nil {
return "", fmt.Errorf("command execution failed: %w", err)
return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
}
return strings.TrimSpace(stdout), nil
// Replace the $(command) with the output
replacement := strings.TrimSpace(stdout)
result = result[:start] + replacement + result[end+1:]
}
if after, ok := strings.CutPrefix(value, "$"); ok {
varName := after
value = r.env.Get(varName)
if value == "" {
// Handle environment variables: $VAR and ${VAR}
searchStart := 0
for {
start := strings.Index(result[searchStart:], "$")
if start == -1 {
break
}
start += searchStart // Adjust for the offset
// Skip if this is part of $( which we already handled
if start+1 < len(result) && result[start+1] == '(' {
// Skip past this $(...)
searchStart = start + 1
continue
}
var varName string
var end int
if start+1 < len(result) && result[start+1] == '{' {
// Handle ${VAR} format
closeIdx := strings.Index(result[start+2:], "}")
if closeIdx == -1 {
return "", fmt.Errorf("unmatched ${ in value: %s", value)
}
varName = result[start+2 : start+2+closeIdx]
end = start + 2 + closeIdx + 1
} else {
// Handle $VAR format - variable names must start with letter or underscore
if start+1 >= len(result) {
return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
}
if result[start+1] != '_' &&
(result[start+1] < 'a' || result[start+1] > 'z') &&
(result[start+1] < 'A' || result[start+1] > 'Z') {
return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
}
end = start + 1
for end < len(result) && (result[end] == '_' ||
(result[end] >= 'a' && result[end] <= 'z') ||
(result[end] >= 'A' && result[end] <= 'Z') ||
(result[end] >= '0' && result[end] <= '9')) {
end++
}
varName = result[start+1 : end]
}
envValue := r.env.Get(varName)
if envValue == "" {
return "", fmt.Errorf("environment variable %q not set", varName)
}
return value, nil
result = result[:start] + envValue + result[end:]
searchStart = start + len(envValue) // Continue searching after the replacement
}
return "", fmt.Errorf("invalid value format: %s", value)
return result, nil
}
type environmentVariableResolver struct {

View File

@@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
envVars: map[string]string{},
expectError: true,
},
{
name: "shell command execution",
value: "$(echo hello)",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
if command == "echo hello" {
return "hello\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "hello",
},
{
name: "shell command with whitespace trimming",
value: "$(echo ' spaced ')",
@@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
}
}
func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
tests := []struct {
name string
value string
envVars map[string]string
shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
expected string
expectError bool
}{
{
name: "command substitution within string",
value: "Bearer $(echo token123)",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
if command == "echo token123" {
return "token123\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "Bearer token123",
},
{
name: "environment variable within string",
value: "Bearer $TOKEN",
envVars: map[string]string{"TOKEN": "sk-ant-123"},
expected: "Bearer sk-ant-123",
},
{
name: "environment variable with braces within string",
value: "Bearer ${TOKEN}",
envVars: map[string]string{"TOKEN": "sk-ant-456"},
expected: "Bearer sk-ant-456",
},
{
name: "mixed command and environment substitution",
value: "$USER-$(date +%Y)-$HOST",
envVars: map[string]string{
"USER": "testuser",
"HOST": "localhost",
},
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
if command == "date +%Y" {
return "2024\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "testuser-2024-localhost",
},
{
name: "multiple command substitutions",
value: "$(echo hello) $(echo world)",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
switch command {
case "echo hello":
return "hello\n", "", nil
case "echo world":
return "world\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "hello world",
},
{
name: "nested parentheses in command",
value: "$(echo $(echo inner))",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
if command == "echo $(echo inner)" {
return "nested\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "nested",
},
{
name: "lone dollar with non-variable chars",
value: "prefix$123suffix", // Numbers can't start variable names
expectError: true,
},
{
name: "dollar with special chars",
value: "a$@b$#c", // Special chars aren't valid in variable names
expectError: true,
},
{
name: "empty environment variable substitution",
value: "Bearer $EMPTY_VAR",
envVars: map[string]string{},
expectError: true,
},
{
name: "unmatched command substitution opening",
value: "Bearer $(echo test",
expectError: true,
},
{
name: "unmatched environment variable braces",
value: "Bearer ${TOKEN",
expectError: true,
},
{
name: "command substitution with error",
value: "Bearer $(false)",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
return "", "", errors.New("command failed")
},
expectError: true,
},
{
name: "complex real-world example",
value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
if command == "cat /tmp/token.txt | base64 -w 0" {
return "c2stYW50LXRlc3Q=\n", "", nil
}
return "", "", errors.New("unexpected command")
},
expected: "Bearer c2stYW50LXRlc3Q=",
},
{
name: "environment variable with underscores and numbers",
value: "Bearer $API_KEY_V2",
envVars: map[string]string{"API_KEY_V2": "sk-test-123"},
expected: "Bearer sk-test-123",
},
{
name: "no substitution needed",
value: "Bearer sk-ant-static-token",
expected: "Bearer sk-ant-static-token",
},
{
name: "incomplete variable at end",
value: "Bearer $",
expectError: true,
},
{
name: "variable with invalid character",
value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names
expectError: true,
},
{
name: "multiple invalid variables",
value: "$1$2$3",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testEnv := env.NewFromMap(tt.envVars)
resolver := &shellVariableResolver{
shell: &mockShell{execFunc: tt.shellFunc},
env: testEnv,
}
result, err := resolver.ResolveValue(tt.value)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
tests := []struct {
name string

View File

@@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
// Check if Authorization header is provided in extra headers
hasBearerAuth := false
if opts.extraHeaders != nil {
for key := range opts.extraHeaders {
if strings.ToLower(key) == "authorization" {
hasBearerAuth = true
break
}
}
}
isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
if opts.apiKey != "" && !hasBearerAuth {
if isBearerToken {
slog.Debug("API key starts with 'Bearer ', using as Authorization header")
anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
} else {
// Use standard X-Api-Key header
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
}
} else if hasBearerAuth {
slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
}
if useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
@@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
maxTokens = int64(a.adjustedMaxTokens)
}
systemBlocks := []anthropic.TextBlockParam{}
// Add custom system prompt prefix if configured
if a.providerOptions.systemPromptPrefix != "" {
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
Text: a.providerOptions.systemPromptPrefix,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
})
}
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
Text: a.providerOptions.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
})
return anthropic.MessageNewParams{
Model: anthropic.Model(model.ID),
MaxTokens: maxTokens,
@@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
Messages: messages,
Tools: tools,
Thinking: thinkingParam,
System: []anthropic.TextBlockParam{
{
Text: a.providerOptions.systemMessage,
CacheControl: anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
},
},
},
System: systemBlocks,
}
}
@@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
close(eventChan)
return
}
// If there is an error we are going to see if we can retry the call
retry, after, retryErr := a.shouldRetry(attempts, err)
if retryErr != nil {

View File

@@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if modelConfig.MaxTokens > 0 {
maxTokens = modelConfig.MaxTokens
}
systemMessage := g.providerOptions.systemMessage
if g.providerOptions.systemPromptPrefix != "" {
systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
Parts: []*genai.Part{{Text: systemMessage}},
},
}
config.Tools = g.convertTools(tools)
@@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if g.providerOptions.maxTokens > 0 {
maxTokens = g.providerOptions.maxTokens
}
systemMessage := g.providerOptions.systemMessage
if g.providerOptions.systemPromptPrefix != "" {
systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
}
history := geminiMessages[:len(geminiMessages)-1] // All but last message
lastMsg := geminiMessages[len(geminiMessages)-1]
config := &genai.GenerateContentConfig{
MaxOutputTokens: int32(maxTokens),
SystemInstruction: &genai.Content{
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
Parts: []*genai.Part{{Text: systemMessage}},
},
}
config.Tools = g.convertTools(tools)

View File

@@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
// Add system message first
openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
systemMessage := o.providerOptions.systemMessage
if o.providerOptions.systemPromptPrefix != "" {
systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
}
openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
for _, msg := range messages {
switch msg.Role {

View File

@@ -61,17 +61,18 @@ type Provider interface {
}
type providerClientOptions struct {
baseURL string
config config.ProviderConfig
apiKey string
modelType config.SelectedModelType
model func(config.SelectedModelType) catwalk.Model
disableCache bool
systemMessage string
maxTokens int64
extraHeaders map[string]string
extraBody map[string]any
extraParams map[string]string
baseURL string
config config.ProviderConfig
apiKey string
modelType config.SelectedModelType
model func(config.SelectedModelType) catwalk.Model
disableCache bool
systemMessage string
systemPromptPrefix string
maxTokens int64
extraHeaders map[string]string
extraBody map[string]any
extraParams map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
// Resolve extra headers
resolvedExtraHeaders := make(map[string]string)
for key, value := range cfg.ExtraHeaders {
resolvedValue, err := config.Get().Resolve(value)
if err != nil {
return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
}
resolvedExtraHeaders[key] = resolvedValue
}
clientOptions := providerClientOptions{
baseURL: cfg.BaseURL,
config: cfg,
apiKey: resolvedAPIKey,
extraHeaders: cfg.ExtraHeaders,
extraBody: cfg.ExtraBody,
baseURL: cfg.BaseURL,
config: cfg,
apiKey: resolvedAPIKey,
extraHeaders: resolvedExtraHeaders,
extraBody: cfg.ExtraBody,
systemPromptPrefix: cfg.SystemPromptPrefix,
model: func(tp config.SelectedModelType) catwalk.Model {
return *config.Get().GetModelByType(tp)
},