mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Merge remote-tracking branch 'origin/main' into list
This commit is contained in:
@@ -77,6 +77,9 @@ type ProviderConfig struct {
|
||||
// Marks the provider as disabled.
|
||||
Disable bool `json:"disable,omitempty"`
|
||||
|
||||
// Custom system prompt prefix.
|
||||
SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
|
||||
|
||||
// Extra headers to send with each request to the provider.
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
// Extra body
|
||||
|
||||
@@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
c.Providers.Del(id)
|
||||
continue
|
||||
}
|
||||
if providerConfig.Type != catwalk.TypeOpenAI {
|
||||
if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic {
|
||||
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
|
||||
c.Providers.Del(id)
|
||||
continue
|
||||
|
||||
@@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
|
||||
})
|
||||
|
||||
t.Run("custom anthropic provider is supported", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: csync.NewMapFrom(map[string]ProviderConfig{
|
||||
"custom-anthropic": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.anthropic.com/v1",
|
||||
Type: catwalk.TypeAnthropic,
|
||||
Models: []catwalk.Model{{
|
||||
ID: "claude-3-sonnet",
|
||||
}},
|
||||
},
|
||||
}),
|
||||
}
|
||||
cfg.setDefaults("/tmp")
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 1)
|
||||
customProvider, exists := cfg.Providers.Get("custom-anthropic")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "custom-anthropic", customProvider.ID)
|
||||
assert.Equal(t, "test-key", customProvider.APIKey)
|
||||
assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
|
||||
assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
|
||||
})
|
||||
|
||||
t.Run("disabled custom provider is removed", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: csync.NewMapFrom(map[string]ProviderConfig{
|
||||
|
||||
@@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver {
|
||||
}
|
||||
|
||||
// ResolveValue is a method for resolving values, such as environment variables.
|
||||
// it will expect strings that start with `$` to be resolved as environment variables or shell commands.
|
||||
// if the string does not start with `$`, it will return the string as is.
|
||||
// it will resolve shell-like variable substitution anywhere in the string, including:
|
||||
// - $(command) for command substitution
|
||||
// - $VAR or ${VAR} for environment variables
|
||||
func (r *shellVariableResolver) ResolveValue(value string) (string, error) {
|
||||
if !strings.HasPrefix(value, "$") {
|
||||
// Special case: lone $ is an error (backward compatibility)
|
||||
if value == "$" {
|
||||
return "", fmt.Errorf("invalid value format: %s", value)
|
||||
}
|
||||
|
||||
// If no $ found, return as-is
|
||||
if !strings.Contains(value, "$") {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") {
|
||||
command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")")
|
||||
result := value
|
||||
|
||||
// Handle command substitution: $(command)
|
||||
for {
|
||||
start := strings.Index(result, "$(")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// Find matching closing parenthesis
|
||||
depth := 0
|
||||
end := -1
|
||||
for i := start + 2; i < len(result); i++ {
|
||||
if result[i] == '(' {
|
||||
depth++
|
||||
} else if result[i] == ')' {
|
||||
if depth == 0 {
|
||||
end = i
|
||||
break
|
||||
}
|
||||
depth--
|
||||
}
|
||||
}
|
||||
|
||||
if end == -1 {
|
||||
return "", fmt.Errorf("unmatched $( in value: %s", value)
|
||||
}
|
||||
|
||||
command := result[start+2 : end]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
stdout, _, err := r.shell.Exec(ctx, command)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("command execution failed: %w", err)
|
||||
return "", fmt.Errorf("command execution failed for '%s': %w", command, err)
|
||||
}
|
||||
return strings.TrimSpace(stdout), nil
|
||||
|
||||
// Replace the $(command) with the output
|
||||
replacement := strings.TrimSpace(stdout)
|
||||
result = result[:start] + replacement + result[end+1:]
|
||||
}
|
||||
|
||||
if after, ok := strings.CutPrefix(value, "$"); ok {
|
||||
varName := after
|
||||
value = r.env.Get(varName)
|
||||
if value == "" {
|
||||
// Handle environment variables: $VAR and ${VAR}
|
||||
searchStart := 0
|
||||
for {
|
||||
start := strings.Index(result[searchStart:], "$")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
start += searchStart // Adjust for the offset
|
||||
|
||||
// Skip if this is part of $( which we already handled
|
||||
if start+1 < len(result) && result[start+1] == '(' {
|
||||
// Skip past this $(...)
|
||||
searchStart = start + 1
|
||||
continue
|
||||
}
|
||||
var varName string
|
||||
var end int
|
||||
|
||||
if start+1 < len(result) && result[start+1] == '{' {
|
||||
// Handle ${VAR} format
|
||||
closeIdx := strings.Index(result[start+2:], "}")
|
||||
if closeIdx == -1 {
|
||||
return "", fmt.Errorf("unmatched ${ in value: %s", value)
|
||||
}
|
||||
varName = result[start+2 : start+2+closeIdx]
|
||||
end = start + 2 + closeIdx + 1
|
||||
} else {
|
||||
// Handle $VAR format - variable names must start with letter or underscore
|
||||
if start+1 >= len(result) {
|
||||
return "", fmt.Errorf("incomplete variable reference at end of string: %s", value)
|
||||
}
|
||||
|
||||
if result[start+1] != '_' &&
|
||||
(result[start+1] < 'a' || result[start+1] > 'z') &&
|
||||
(result[start+1] < 'A' || result[start+1] > 'Z') {
|
||||
return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value)
|
||||
}
|
||||
|
||||
end = start + 1
|
||||
for end < len(result) && (result[end] == '_' ||
|
||||
(result[end] >= 'a' && result[end] <= 'z') ||
|
||||
(result[end] >= 'A' && result[end] <= 'Z') ||
|
||||
(result[end] >= '0' && result[end] <= '9')) {
|
||||
end++
|
||||
}
|
||||
varName = result[start+1 : end]
|
||||
}
|
||||
|
||||
envValue := r.env.Get(varName)
|
||||
if envValue == "" {
|
||||
return "", fmt.Errorf("environment variable %q not set", varName)
|
||||
}
|
||||
return value, nil
|
||||
|
||||
result = result[:start] + envValue + result[end:]
|
||||
searchStart = start + len(envValue) // Continue searching after the replacement
|
||||
}
|
||||
return "", fmt.Errorf("invalid value format: %s", value)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type environmentVariableResolver struct {
|
||||
|
||||
@@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
|
||||
envVars: map[string]string{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "shell command execution",
|
||||
value: "$(echo hello)",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
if command == "echo hello" {
|
||||
return "hello\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "hello",
|
||||
},
|
||||
|
||||
{
|
||||
name: "shell command with whitespace trimming",
|
||||
value: "$(echo ' spaced ')",
|
||||
@@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
envVars map[string]string
|
||||
shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error)
|
||||
expected string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "command substitution within string",
|
||||
value: "Bearer $(echo token123)",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
if command == "echo token123" {
|
||||
return "token123\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "Bearer token123",
|
||||
},
|
||||
{
|
||||
name: "environment variable within string",
|
||||
value: "Bearer $TOKEN",
|
||||
envVars: map[string]string{"TOKEN": "sk-ant-123"},
|
||||
expected: "Bearer sk-ant-123",
|
||||
},
|
||||
{
|
||||
name: "environment variable with braces within string",
|
||||
value: "Bearer ${TOKEN}",
|
||||
envVars: map[string]string{"TOKEN": "sk-ant-456"},
|
||||
expected: "Bearer sk-ant-456",
|
||||
},
|
||||
{
|
||||
name: "mixed command and environment substitution",
|
||||
value: "$USER-$(date +%Y)-$HOST",
|
||||
envVars: map[string]string{
|
||||
"USER": "testuser",
|
||||
"HOST": "localhost",
|
||||
},
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
if command == "date +%Y" {
|
||||
return "2024\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "testuser-2024-localhost",
|
||||
},
|
||||
{
|
||||
name: "multiple command substitutions",
|
||||
value: "$(echo hello) $(echo world)",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
switch command {
|
||||
case "echo hello":
|
||||
return "hello\n", "", nil
|
||||
case "echo world":
|
||||
return "world\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "hello world",
|
||||
},
|
||||
{
|
||||
name: "nested parentheses in command",
|
||||
value: "$(echo $(echo inner))",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
if command == "echo $(echo inner)" {
|
||||
return "nested\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "nested",
|
||||
},
|
||||
{
|
||||
name: "lone dollar with non-variable chars",
|
||||
value: "prefix$123suffix", // Numbers can't start variable names
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "dollar with special chars",
|
||||
value: "a$@b$#c", // Special chars aren't valid in variable names
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty environment variable substitution",
|
||||
value: "Bearer $EMPTY_VAR",
|
||||
envVars: map[string]string{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unmatched command substitution opening",
|
||||
value: "Bearer $(echo test",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unmatched environment variable braces",
|
||||
value: "Bearer ${TOKEN",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "command substitution with error",
|
||||
value: "Bearer $(false)",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
return "", "", errors.New("command failed")
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "complex real-world example",
|
||||
value: "Bearer $(cat /tmp/token.txt | base64 -w 0)",
|
||||
shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) {
|
||||
if command == "cat /tmp/token.txt | base64 -w 0" {
|
||||
return "c2stYW50LXRlc3Q=\n", "", nil
|
||||
}
|
||||
return "", "", errors.New("unexpected command")
|
||||
},
|
||||
expected: "Bearer c2stYW50LXRlc3Q=",
|
||||
},
|
||||
{
|
||||
name: "environment variable with underscores and numbers",
|
||||
value: "Bearer $API_KEY_V2",
|
||||
envVars: map[string]string{"API_KEY_V2": "sk-test-123"},
|
||||
expected: "Bearer sk-test-123",
|
||||
},
|
||||
{
|
||||
name: "no substitution needed",
|
||||
value: "Bearer sk-ant-static-token",
|
||||
expected: "Bearer sk-ant-static-token",
|
||||
},
|
||||
{
|
||||
name: "incomplete variable at end",
|
||||
value: "Bearer $",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "variable with invalid character",
|
||||
value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "multiple invalid variables",
|
||||
value: "$1$2$3",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testEnv := env.NewFromMap(tt.envVars)
|
||||
resolver := &shellVariableResolver{
|
||||
shell: &mockShell{execFunc: tt.shellFunc},
|
||||
env: testEnv,
|
||||
}
|
||||
|
||||
result, err := resolver.ResolveValue(tt.value)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl
|
||||
|
||||
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
|
||||
anthropicClientOptions := []option.RequestOption{}
|
||||
if opts.apiKey != "" {
|
||||
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
|
||||
|
||||
// Check if Authorization header is provided in extra headers
|
||||
hasBearerAuth := false
|
||||
if opts.extraHeaders != nil {
|
||||
for key := range opts.extraHeaders {
|
||||
if strings.ToLower(key) == "authorization" {
|
||||
hasBearerAuth = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
|
||||
|
||||
if opts.apiKey != "" && !hasBearerAuth {
|
||||
if isBearerToken {
|
||||
slog.Debug("API key starts with 'Bearer ', using as Authorization header")
|
||||
anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
|
||||
} else {
|
||||
// Use standard X-Api-Key header
|
||||
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
|
||||
}
|
||||
} else if hasBearerAuth {
|
||||
slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
|
||||
}
|
||||
if useBedrock {
|
||||
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
|
||||
@@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
|
||||
maxTokens = int64(a.adjustedMaxTokens)
|
||||
}
|
||||
|
||||
systemBlocks := []anthropic.TextBlockParam{}
|
||||
|
||||
// Add custom system prompt prefix if configured
|
||||
if a.providerOptions.systemPromptPrefix != "" {
|
||||
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
|
||||
Text: a.providerOptions.systemPromptPrefix,
|
||||
CacheControl: anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
|
||||
Text: a.providerOptions.systemMessage,
|
||||
CacheControl: anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
|
||||
return anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model.ID),
|
||||
MaxTokens: maxTokens,
|
||||
@@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
Thinking: thinkingParam,
|
||||
System: []anthropic.TextBlockParam{
|
||||
{
|
||||
Text: a.providerOptions.systemMessage,
|
||||
CacheControl: anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
System: systemBlocks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
|
||||
close(eventChan)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
retry, after, retryErr := a.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
|
||||
@@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
if modelConfig.MaxTokens > 0 {
|
||||
maxTokens = modelConfig.MaxTokens
|
||||
}
|
||||
systemMessage := g.providerOptions.systemMessage
|
||||
if g.providerOptions.systemPromptPrefix != "" {
|
||||
systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
|
||||
}
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
config := &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
Parts: []*genai.Part{{Text: systemMessage}},
|
||||
},
|
||||
}
|
||||
config.Tools = g.convertTools(tools)
|
||||
@@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
if g.providerOptions.maxTokens > 0 {
|
||||
maxTokens = g.providerOptions.maxTokens
|
||||
}
|
||||
systemMessage := g.providerOptions.systemMessage
|
||||
if g.providerOptions.systemPromptPrefix != "" {
|
||||
systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage
|
||||
}
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
config := &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
Parts: []*genai.Part{{Text: systemMessage}},
|
||||
},
|
||||
}
|
||||
config.Tools = g.convertTools(tools)
|
||||
|
||||
@@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
|
||||
|
||||
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
|
||||
// Add system message first
|
||||
openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
|
||||
systemMessage := o.providerOptions.systemMessage
|
||||
if o.providerOptions.systemPromptPrefix != "" {
|
||||
systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
|
||||
}
|
||||
openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
|
||||
@@ -61,17 +61,18 @@ type Provider interface {
|
||||
}
|
||||
|
||||
type providerClientOptions struct {
|
||||
baseURL string
|
||||
config config.ProviderConfig
|
||||
apiKey string
|
||||
modelType config.SelectedModelType
|
||||
model func(config.SelectedModelType) catwalk.Model
|
||||
disableCache bool
|
||||
systemMessage string
|
||||
maxTokens int64
|
||||
extraHeaders map[string]string
|
||||
extraBody map[string]any
|
||||
extraParams map[string]string
|
||||
baseURL string
|
||||
config config.ProviderConfig
|
||||
apiKey string
|
||||
modelType config.SelectedModelType
|
||||
model func(config.SelectedModelType) catwalk.Model
|
||||
disableCache bool
|
||||
systemMessage string
|
||||
systemPromptPrefix string
|
||||
maxTokens int64
|
||||
extraHeaders map[string]string
|
||||
extraBody map[string]any
|
||||
extraParams map[string]string
|
||||
}
|
||||
|
||||
type ProviderClientOption func(*providerClientOptions)
|
||||
@@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
|
||||
}
|
||||
|
||||
// Resolve extra headers
|
||||
resolvedExtraHeaders := make(map[string]string)
|
||||
for key, value := range cfg.ExtraHeaders {
|
||||
resolvedValue, err := config.Get().Resolve(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
|
||||
}
|
||||
resolvedExtraHeaders[key] = resolvedValue
|
||||
}
|
||||
|
||||
clientOptions := providerClientOptions{
|
||||
baseURL: cfg.BaseURL,
|
||||
config: cfg,
|
||||
apiKey: resolvedAPIKey,
|
||||
extraHeaders: cfg.ExtraHeaders,
|
||||
extraBody: cfg.ExtraBody,
|
||||
baseURL: cfg.BaseURL,
|
||||
config: cfg,
|
||||
apiKey: resolvedAPIKey,
|
||||
extraHeaders: resolvedExtraHeaders,
|
||||
extraBody: cfg.ExtraBody,
|
||||
systemPromptPrefix: cfg.SystemPromptPrefix,
|
||||
model: func(tp config.SelectedModelType) catwalk.Model {
|
||||
return *config.Get().GetModelByType(tp)
|
||||
},
|
||||
|
||||
@@ -161,10 +161,17 @@ func (m *editorCmp) send() tea.Cmd {
|
||||
)
|
||||
}
|
||||
|
||||
func (m *editorCmp) repositionCompletions() tea.Msg {
|
||||
x, y := m.completionsPosition()
|
||||
return completions.RepositionCompletionsMsg{X: x, Y: y}
|
||||
}
|
||||
|
||||
func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
var cmds []tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
return m, m.repositionCompletions
|
||||
case filepicker.FilePickedMsg:
|
||||
if len(m.attachments) >= maxAttachments {
|
||||
return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments))
|
||||
@@ -182,32 +189,37 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
if item, ok := msg.Value.(FileCompletionItem); ok {
|
||||
word := m.textarea.Word()
|
||||
// If the selected item is a file, insert its path into the textarea
|
||||
value := m.textarea.Value()
|
||||
value = value[:m.completionsStartIndex]
|
||||
value += item.Path
|
||||
value = value[:m.completionsStartIndex] + // Remove the current query
|
||||
item.Path + // Insert the file path
|
||||
value[m.completionsStartIndex+len(word):] // Append the rest of the value
|
||||
// XXX: This will always move the cursor to the end of the textarea.
|
||||
m.textarea.SetValue(value)
|
||||
m.textarea.MoveToEnd()
|
||||
if !msg.Insert {
|
||||
m.isCompletionsOpen = false
|
||||
m.currentQuery = ""
|
||||
m.completionsStartIndex = 0
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
case openEditorMsg:
|
||||
m.textarea.SetValue(msg.Text)
|
||||
m.textarea.MoveToEnd()
|
||||
case tea.KeyPressMsg:
|
||||
cur := m.textarea.Cursor()
|
||||
curIdx := m.textarea.Width()*cur.Y + cur.X
|
||||
switch {
|
||||
// Completions
|
||||
case msg.String() == "/" && !m.isCompletionsOpen &&
|
||||
// only show if beginning of prompt, or if previous char is a space:
|
||||
(len(m.textarea.Value()) == 0 || m.textarea.Value()[len(m.textarea.Value())-1] == ' '):
|
||||
// only show if beginning of prompt, or if previous char is a space or newline:
|
||||
(len(m.textarea.Value()) == 0 || unicode.IsSpace(rune(m.textarea.Value()[len(m.textarea.Value())-1]))):
|
||||
m.isCompletionsOpen = true
|
||||
m.currentQuery = ""
|
||||
m.completionsStartIndex = len(m.textarea.Value())
|
||||
m.completionsStartIndex = curIdx
|
||||
cmds = append(cmds, m.startCompletions)
|
||||
case m.isCompletionsOpen && m.textarea.Cursor().X <= m.completionsStartIndex:
|
||||
case m.isCompletionsOpen && curIdx <= m.completionsStartIndex:
|
||||
cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
|
||||
}
|
||||
if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) {
|
||||
@@ -244,6 +256,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if key.Matches(msg, m.keyMap.Newline) {
|
||||
m.textarea.InsertRune('\n')
|
||||
cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{}))
|
||||
}
|
||||
// Handle Enter key
|
||||
if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) {
|
||||
@@ -275,12 +288,18 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// XXX: wont' work if editing in the middle of the field.
|
||||
m.completionsStartIndex = strings.LastIndex(m.textarea.Value(), word)
|
||||
m.currentQuery = word[1:]
|
||||
x, y := m.completionsPosition()
|
||||
x -= len(m.currentQuery)
|
||||
m.isCompletionsOpen = true
|
||||
cmds = append(cmds, util.CmdHandler(completions.FilterCompletionsMsg{
|
||||
Query: m.currentQuery,
|
||||
Reopen: m.isCompletionsOpen,
|
||||
}))
|
||||
} else {
|
||||
cmds = append(cmds,
|
||||
util.CmdHandler(completions.FilterCompletionsMsg{
|
||||
Query: m.currentQuery,
|
||||
Reopen: m.isCompletionsOpen,
|
||||
X: x,
|
||||
Y: y,
|
||||
}),
|
||||
)
|
||||
} else if m.isCompletionsOpen {
|
||||
m.isCompletionsOpen = false
|
||||
m.currentQuery = ""
|
||||
m.completionsStartIndex = 0
|
||||
@@ -293,6 +312,16 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *editorCmp) completionsPosition() (int, int) {
|
||||
cur := m.textarea.Cursor()
|
||||
if cur == nil {
|
||||
return m.x, m.y + 1 // adjust for padding
|
||||
}
|
||||
x := cur.X + m.x
|
||||
y := cur.Y + m.y + 1 // adjust for padding
|
||||
return x, y
|
||||
}
|
||||
|
||||
func (m *editorCmp) Cursor() *tea.Cursor {
|
||||
cursor := m.textarea.Cursor()
|
||||
if cursor != nil {
|
||||
@@ -373,9 +402,7 @@ func (m *editorCmp) startCompletions() tea.Msg {
|
||||
})
|
||||
}
|
||||
|
||||
cur := m.textarea.Cursor()
|
||||
x := cur.X + m.x // adjust for padding
|
||||
y := cur.Y + m.y + 1
|
||||
x, y := m.completionsPosition()
|
||||
return completions.OpenCompletionsMsg{
|
||||
Completions: completionItems,
|
||||
X: x,
|
||||
|
||||
@@ -27,6 +27,12 @@ type OpenCompletionsMsg struct {
|
||||
type FilterCompletionsMsg struct {
|
||||
Query string // The query to filter completions
|
||||
Reopen bool
|
||||
X int // X position for the completions popup
|
||||
Y int // Y position for the completions popup
|
||||
}
|
||||
|
||||
type RepositionCompletionsMsg struct {
|
||||
X, Y int
|
||||
}
|
||||
|
||||
type CompletionsClosedMsg struct{}
|
||||
@@ -53,18 +59,24 @@ type Completions interface {
|
||||
type listModel = list.FilterableList[list.CompletionItem[any]]
|
||||
|
||||
type completionsCmp struct {
|
||||
width int
|
||||
height int // Height of the completions component`
|
||||
x int // X position for the completions popup
|
||||
y int // Y position for the completions popup
|
||||
open bool // Indicates if the completions are open
|
||||
keyMap KeyMap
|
||||
wWidth int // The window width
|
||||
wHeight int // The window height
|
||||
width int
|
||||
lastWidth int
|
||||
height int // Height of the completions component`
|
||||
x, xorig int // X position for the completions popup
|
||||
y int // Y position for the completions popup
|
||||
open bool // Indicates if the completions are open
|
||||
keyMap KeyMap
|
||||
|
||||
list listModel
|
||||
query string // The current filter query
|
||||
}
|
||||
|
||||
const maxCompletionsWidth = 80 // Maximum width for the completions popup
|
||||
const (
|
||||
maxCompletionsWidth = 80 // Maximum width for the completions popup
|
||||
minCompletionsWidth = 20 // Minimum width for the completions popup
|
||||
)
|
||||
|
||||
func New() Completions {
|
||||
completionsKeyMap := DefaultKeyMap()
|
||||
@@ -88,7 +100,7 @@ func New() Completions {
|
||||
)
|
||||
return &completionsCmp{
|
||||
width: 0,
|
||||
height: 0,
|
||||
height: maxCompletionsHeight,
|
||||
list: l,
|
||||
query: "",
|
||||
keyMap: completionsKeyMap,
|
||||
@@ -107,8 +119,7 @@ func (c *completionsCmp) Init() tea.Cmd {
|
||||
func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
c.width = min(msg.Width-c.x, maxCompletionsWidth)
|
||||
c.height = min(msg.Height-c.y, 15)
|
||||
c.wWidth, c.wHeight = msg.Width, msg.Height
|
||||
return c, nil
|
||||
case tea.KeyPressMsg:
|
||||
switch {
|
||||
@@ -156,13 +167,16 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case key.Matches(msg, c.keyMap.Cancel):
|
||||
return c, util.CmdHandler(CloseCompletionsMsg{})
|
||||
}
|
||||
case RepositionCompletionsMsg:
|
||||
c.x, c.y = msg.X, msg.Y
|
||||
c.adjustPosition()
|
||||
case CloseCompletionsMsg:
|
||||
c.open = false
|
||||
return c, util.CmdHandler(CompletionsClosedMsg{})
|
||||
case OpenCompletionsMsg:
|
||||
c.open = true
|
||||
c.query = ""
|
||||
c.x = msg.X
|
||||
c.x, c.xorig = msg.X, msg.X
|
||||
c.y = msg.Y
|
||||
items := []list.CompletionItem[any]{}
|
||||
t := styles.CurrentTheme()
|
||||
@@ -174,10 +188,18 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
)
|
||||
items = append(items, item)
|
||||
}
|
||||
c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height
|
||||
width := listWidth(items)
|
||||
if len(items) == 0 {
|
||||
width = listWidth(c.list.Items())
|
||||
}
|
||||
if c.x+width >= c.wWidth {
|
||||
c.x = c.wWidth - width - 1
|
||||
}
|
||||
c.width = width
|
||||
c.height = max(min(maxCompletionsHeight, len(items)), 1) // Ensure at least 1 item height
|
||||
return c, tea.Batch(
|
||||
c.list.SetSize(c.width, c.height),
|
||||
c.list.SetItems(items),
|
||||
c.list.SetSize(c.width, c.height),
|
||||
util.CmdHandler(CompletionsOpenedMsg{}),
|
||||
)
|
||||
case FilterCompletionsMsg:
|
||||
@@ -201,8 +223,11 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
c.query = msg.Query
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, c.list.Filter(msg.Query))
|
||||
itemsLen := len(c.list.Items())
|
||||
c.height = max(min(maxCompletionsHeight, itemsLen), 1)
|
||||
items := c.list.Items()
|
||||
itemsLen := len(items)
|
||||
c.xorig = msg.X
|
||||
c.x, c.y = msg.X, msg.Y
|
||||
c.adjustPosition()
|
||||
cmds = append(cmds, c.list.SetSize(c.width, c.height))
|
||||
if itemsLen == 0 {
|
||||
cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{}))
|
||||
@@ -215,21 +240,54 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *completionsCmp) adjustPosition() {
|
||||
items := c.list.Items()
|
||||
itemsLen := len(items)
|
||||
width := listWidth(items)
|
||||
c.lastWidth = c.width
|
||||
if c.x < 0 || width < c.lastWidth {
|
||||
c.x = c.xorig
|
||||
} else if c.x+width >= c.wWidth {
|
||||
c.x = c.wWidth - width - 1
|
||||
}
|
||||
c.width = width
|
||||
c.height = max(min(maxCompletionsHeight, itemsLen), 1)
|
||||
}
|
||||
|
||||
// View implements Completions.
|
||||
func (c *completionsCmp) View() string {
|
||||
if !c.open || len(c.list.Items()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.style().Render(c.list.View())
|
||||
}
|
||||
|
||||
func (c *completionsCmp) style() lipgloss.Style {
|
||||
t := styles.CurrentTheme()
|
||||
return t.S().Base.
|
||||
style := t.S().Base.
|
||||
Width(c.width).
|
||||
Height(c.height).
|
||||
Background(t.BgSubtle)
|
||||
|
||||
return style.Render(c.list.View())
|
||||
}
|
||||
|
||||
// listWidth returns the width of the last 10 items in the list, which is used
|
||||
// to determine the width of the completions popup.
|
||||
// Note this only works for [completionItemCmp] items.
|
||||
func listWidth[T any](items []T) int {
|
||||
var width int
|
||||
if len(items) == 0 {
|
||||
return width
|
||||
}
|
||||
|
||||
for i := len(items) - 1; i >= 0 && i >= len(items)-10; i-- {
|
||||
item, ok := any(items[i]).(*completionItemCmp)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemWidth := lipgloss.Width(item.text) + 2 // +2 for padding
|
||||
width = max(width, itemWidth)
|
||||
}
|
||||
|
||||
return width
|
||||
}
|
||||
|
||||
func (c *completionsCmp) Open() bool {
|
||||
|
||||
@@ -172,7 +172,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
return p, nil
|
||||
case tea.WindowSizeMsg:
|
||||
return p, p.SetSize(msg.Width, msg.Height)
|
||||
u, cmd := p.editor.Update(msg)
|
||||
p.editor = u.(editor.Editor)
|
||||
return p, tea.Batch(p.SetSize(msg.Width, msg.Height), cmd)
|
||||
case CancelTimerExpiredMsg:
|
||||
p.isCanceling = false
|
||||
return p, nil
|
||||
|
||||
@@ -111,19 +111,10 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return a, a.handleWindowResize(msg.Width, msg.Height)
|
||||
|
||||
// Completions messages
|
||||
case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg:
|
||||
case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg,
|
||||
completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg:
|
||||
u, completionCmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
switch msg := msg.(type) {
|
||||
case completions.OpenCompletionsMsg:
|
||||
x, _ := a.completions.Position()
|
||||
if a.completions.Width()+x >= a.wWidth {
|
||||
// Adjust X position to fit in the window.
|
||||
msg.X = a.wWidth - a.completions.Width() - 1
|
||||
u, completionCmd = a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
}
|
||||
}
|
||||
return a, completionCmd
|
||||
|
||||
// Dialog messages
|
||||
|
||||
Reference in New Issue
Block a user