chore: move to the new config

This commit is contained in:
Kujtim Hoxha
2025-06-27 10:55:44 +02:00
parent 6e0feda9c4
commit 565ab85eb9
56 changed files with 1139 additions and 3280 deletions

View File

@@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly from
}
cwd = c
}
_, err := config.Load(cwd, debug)
_, err := config.Init(cwd, debug)
if err != nil {
return err
}

View File

@@ -1,3 +1,4 @@
// TODO: FIX THIS
package main
import (
@@ -6,7 +7,6 @@ import (
"os"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
)
// JSONSchemaType represents a JSON Schema type
@@ -192,22 +192,10 @@ func generateSchema() map[string]any {
},
}
// Add known providers
knownProviders := []string{
string(models.ProviderAnthropic),
string(models.ProviderOpenAI),
string(models.ProviderGemini),
string(models.ProviderGROQ),
string(models.ProviderOpenRouter),
string(models.ProviderBedrock),
string(models.ProviderAzure),
string(models.ProviderVertexAI),
}
providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{
"type": "string",
"description": "Provider type",
"enum": knownProviders,
"enum": []string{},
}
schema["properties"].(map[string]any)["providers"] = providerSchema
@@ -241,9 +229,7 @@ func generateSchema() map[string]any {
// Add model enum
modelEnum := []string{}
for modelID := range models.SupportedModels {
modelEnum = append(modelEnum, string(modelID))
}
agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum
// Add specific agent properties
@@ -251,7 +237,6 @@ func generateSchema() map[string]any {
knownAgents := []string{
string(config.AgentCoder),
string(config.AgentTask),
string(config.AgentTitle),
}
for _, agentName := range knownAgents {

View File

@@ -9,7 +9,7 @@ import (
"sync"
"time"
"github.com/charmbracelet/crush/internal/config"
configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
@@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
// Initialize LSP clients in the background
go app.initLSPClients(ctx)
cfg := configv2.Get()
coderAgentCfg := cfg.Agents[configv2.AgentCoder]
if coderAgentCfg.ID == "" {
return nil, fmt.Errorf("coder agent configuration is missing")
}
var err error
app.CoderAgent, err = agent.NewAgent(
config.AgentCoder,
coderAgentCfg,
app.Permissions,
app.Sessions,
app.Messages,
agent.CoderAgentTools(
app.Permissions,
app.Sessions,
app.Messages,
app.History,
app.LSPClients,
),
app.History,
app.LSPClients,
)
if err != nil {
logging.Error("Failed to create coder agent", err)

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
package configv2
package config
import (
"encoding/json"
@@ -28,7 +28,7 @@ func TestConfigWithEnv(t *testing.T) {
os.Setenv("GEMINI_API_KEY", "test-gemini-key")
os.Setenv("XAI_API_KEY", "test-xai-key")
os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
cfg := InitConfig(cwdDir)
cfg, _ := Init(cwdDir, false)
data, _ := json.MarshalIndent(cfg, "", " ")
fmt.Println(string(data))
assert.Len(t, cfg.Providers, 5)

View File

@@ -1,4 +1,4 @@
package configv2
package config
import (
"fmt"

View File

@@ -17,23 +17,20 @@ type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory
func ShouldShowInitDialog() (bool, error) {
if cfg == nil {
// ProjectNeedsInitialization checks if the current project needs initialization
func ProjectNeedsInitialization() (bool, error) {
if instance == nil {
return false, fmt.Errorf("config not loaded")
}
// Create the flag file path
flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
// Check if the flag file exists
_, err := os.Stat(flagFilePath)
if err == nil {
// File exists, don't show the dialog
return false, nil
}
// If the error is not "file not found", return the error
if !os.IsNotExist(err) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
@@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) {
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
}
if crushExists {
// CRUSH.md already exists, don't show the dialog
return false, nil
}
// File doesn't exist, show the dialog
return true, nil
}
@@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) {
// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
if cfg == nil {
if instance == nil {
return fmt.Errorf("config not loaded")
}
// Create the flag file path
flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
// Create an empty file to mark the project as initialized
file, err := os.Create(flagFilePath)
if err != nil {
return fmt.Errorf("failed to create init flag file: %w", err)

View File

@@ -1,4 +1,4 @@
package configv2
package config
import (
"encoding/json"

View File

@@ -1,660 +0,0 @@
package configv2
import (
"encoding/json"
"errors"
"maps"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/logging"
)
const (
defaultDataDirectory = ".crush"
defaultLogLevel = "info"
appName = "crush"
MaxTokensFallbackDefault = 4096
)
var defaultContextPaths = []string{
".github/copilot-instructions.md",
".cursorrules",
".cursor/rules/",
"CLAUDE.md",
"CLAUDE.local.md",
"crush.md",
"crush.local.md",
"Crush.md",
"Crush.local.md",
"CRUSH.md",
"CRUSH.local.md",
}
type AgentID string
const (
AgentCoder AgentID = "coder"
AgentTask AgentID = "task"
AgentTitle AgentID = "title"
AgentSummarize AgentID = "summarize"
)
type Model struct {
ID string `json:"id"`
Name string `json:"model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow int64 `json:"context_window"`
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
ReasoningEffort string `json:"reasoning_effort"`
SupportsImages bool `json:"supports_attachments"`
}
type VertexAIOptions struct {
APIKey string `json:"api_key,omitempty"`
Project string `json:"project,omitempty"`
Location string `json:"location,omitempty"`
}
type ProviderConfig struct {
ID provider.InferenceProvider `json:"id"`
BaseURL string `json:"base_url,omitempty"`
ProviderType provider.Type `json:"provider_type"`
APIKey string `json:"api_key,omitempty"`
Disabled bool `json:"disabled"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// used for e.x for vertex to set the project
ExtraParams map[string]string `json:"extra_params,omitempty"`
DefaultLargeModel string `json:"default_large_model,omitempty"`
DefaultSmallModel string `json:"default_small_model,omitempty"`
Models []Model `json:"models,omitempty"`
}
type Agent struct {
ID AgentID `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
// This is the id of the system prompt used by the agent
Disabled bool `json:"disabled"`
Provider provider.InferenceProvider `json:"provider"`
Model string `json:"model"`
// The available tools for the agent
// if this is nil, all tools are available
AllowedTools []string `json:"allowed_tools"`
// this tells us which MCPs are available for this agent
// if this is empty all mcps are available
// the string array is the list of tools from the AllowedMCP the agent has available
// if the string array is nil, all tools from the AllowedMCP are available
AllowedMCP map[string][]string `json:"allowed_mcp"`
// The list of LSPs that this agent can use
// if this is nil, all LSPs are available
AllowedLSP []string `json:"allowed_lsp"`
// Overrides the context paths for this agent
ContextPaths []string `json:"context_paths"`
}
type MCPType string
const (
MCPStdio MCPType = "stdio"
MCPSse MCPType = "sse"
)
type MCP struct {
Command string `json:"command"`
Env []string `json:"env"`
Args []string `json:"args"`
Type MCPType `json:"type"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}
type LSPConfig struct {
Disabled bool `json:"enabled"`
Command string `json:"command"`
Args []string `json:"args"`
Options any `json:"options"`
}
type TUIOptions struct {
CompactMode bool `json:"compact_mode"`
// Here we can add themes later or any TUI related options
}
type Options struct {
ContextPaths []string `json:"context_paths"`
TUI TUIOptions `json:"tui"`
Debug bool `json:"debug"`
DebugLSP bool `json:"debug_lsp"`
DisableAutoSummarize bool `json:"disable_auto_summarize"`
// Relative to the cwd
DataDirectory string `json:"data_directory"`
}
type Config struct {
// List of configured providers
Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`
// List of configured agents
Agents map[AgentID]Agent `json:"agents,omitempty"`
// List of configured MCPs
MCP map[string]MCP `json:"mcp,omitempty"`
// List of configured LSPs
LSP map[string]LSPConfig `json:"lsp,omitempty"`
// Miscellaneous options
Options Options `json:"options"`
}
var (
instance *Config // The single instance of the Singleton
cwd string
once sync.Once // Ensures the initialization happens only once
)
func loadConfig(cwd string) (*Config, error) {
// First read the global config file
cfgPath := ConfigPath()
cfg := defaultConfigBasedOnEnv()
var globalCfg *Config
if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
// some other error occurred while checking the file
return nil, err
} else if err == nil {
// config file exists, read it
file, err := os.ReadFile(cfgPath)
if err != nil {
return nil, err
}
globalCfg = &Config{}
if err := json.Unmarshal(file, globalCfg); err != nil {
return nil, err
}
} else {
// config file does not exist, create a new one
globalCfg = &Config{}
}
var localConfig *Config
// Global config loaded, now read the local config file
localConfigPath := filepath.Join(cwd, "crush.json")
if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
// some other error occurred while checking the file
return nil, err
} else if err == nil {
// local config file exists, read it
file, err := os.ReadFile(localConfigPath)
if err != nil {
return nil, err
}
localConfig = &Config{}
if err := json.Unmarshal(file, localConfig); err != nil {
return nil, err
}
}
// merge options
mergeOptions(cfg, globalCfg, localConfig)
mergeProviderConfigs(cfg, globalCfg, localConfig)
// no providers found the app is not initialized yet
if len(cfg.Providers) == 0 {
return cfg, nil
}
preferredProvider := getPreferredProvider(cfg.Providers)
if preferredProvider == nil {
return nil, errors.New("no valid providers configured")
}
agents := map[AgentID]Agent{
AgentCoder: {
ID: AgentCoder,
Name: "Coder",
Description: "An agent that helps with executing coding tasks.",
Provider: preferredProvider.ID,
Model: preferredProvider.DefaultLargeModel,
ContextPaths: cfg.Options.ContextPaths,
// All tools allowed
},
AgentTask: {
ID: AgentTask,
Name: "Task",
Description: "An agent that helps with searching for context and finding implementation details.",
Provider: preferredProvider.ID,
Model: preferredProvider.DefaultLargeModel,
ContextPaths: cfg.Options.ContextPaths,
AllowedTools: []string{
"glob",
"grep",
"ls",
"sourcegraph",
"view",
},
// NO MCPs or LSPs by default
AllowedMCP: map[string][]string{},
AllowedLSP: []string{},
},
AgentTitle: {
ID: AgentTitle,
Name: "Title",
Description: "An agent that helps with generating titles for sessions.",
Provider: preferredProvider.ID,
Model: preferredProvider.DefaultSmallModel,
ContextPaths: cfg.Options.ContextPaths,
AllowedTools: []string{},
// NO MCPs or LSPs by default
AllowedMCP: map[string][]string{},
AllowedLSP: []string{},
},
AgentSummarize: {
ID: AgentSummarize,
Name: "Summarize",
Description: "An agent that helps with summarizing sessions.",
Provider: preferredProvider.ID,
Model: preferredProvider.DefaultSmallModel,
ContextPaths: cfg.Options.ContextPaths,
AllowedTools: []string{},
// NO MCPs or LSPs by default
AllowedMCP: map[string][]string{},
AllowedLSP: []string{},
},
}
cfg.Agents = agents
mergeAgents(cfg, globalCfg, localConfig)
mergeMCPs(cfg, globalCfg, localConfig)
mergeLSPs(cfg, globalCfg, localConfig)
return cfg, nil
}
func InitConfig(workingDir string) *Config {
once.Do(func() {
cwd = workingDir
cfg, err := loadConfig(cwd)
if err != nil {
// TODO: Handle this better
panic("Failed to load config: " + err.Error())
}
instance = cfg
})
return instance
}
func GetConfig() *Config {
if instance == nil {
// TODO: Handle this better
panic("Config not initialized. Call InitConfig first.")
}
return instance
}
func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
providers := Providers()
for _, p := range providers {
if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
return &providerConfig
}
}
// if none found return the first configured provider
for _, providerConfig := range configuredProviders {
if !providerConfig.Disabled {
return &providerConfig
}
}
return nil
}
func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
if other.APIKey != "" {
base.APIKey = other.APIKey
}
// Only change these options if the provider is not a known provider
if !slices.Contains(provider.KnownProviders(), p) {
if other.BaseURL != "" {
base.BaseURL = other.BaseURL
}
if other.ProviderType != "" {
base.ProviderType = other.ProviderType
}
if len(base.ExtraHeaders) > 0 {
if base.ExtraHeaders == nil {
base.ExtraHeaders = make(map[string]string)
}
maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
}
if len(other.ExtraParams) > 0 {
if base.ExtraParams == nil {
base.ExtraParams = make(map[string]string)
}
maps.Copy(base.ExtraParams, other.ExtraParams)
}
}
if other.Disabled {
base.Disabled = other.Disabled
}
if other.DefaultLargeModel != "" {
base.DefaultLargeModel = other.DefaultLargeModel
}
// Add new models if they don't exist
if other.Models != nil {
for _, model := range other.Models {
// check if the model already exists
exists := false
for _, existingModel := range base.Models {
if existingModel.ID == model.ID {
exists = true
break
}
}
if !exists {
base.Models = append(base.Models, model)
}
}
}
return base
}
func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
if !slices.Contains(provider.KnownProviders(), p) {
if providerConfig.ProviderType != provider.TypeOpenAI {
return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
}
if providerConfig.BaseURL == "" {
return errors.New("base URL must be set for custom providers")
}
if providerConfig.APIKey == "" {
return errors.New("API key must be set for custom providers")
}
}
return nil
}
func mergeOptions(base, global, local *Config) {
for _, cfg := range []*Config{global, local} {
if cfg == nil {
continue
}
baseOptions := base.Options
other := cfg.Options
if len(other.ContextPaths) > 0 {
baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
}
if other.TUI.CompactMode {
baseOptions.TUI.CompactMode = other.TUI.CompactMode
}
if other.Debug {
baseOptions.Debug = other.Debug
}
if other.DebugLSP {
baseOptions.DebugLSP = other.DebugLSP
}
if other.DisableAutoSummarize {
baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
}
if other.DataDirectory != "" {
baseOptions.DataDirectory = other.DataDirectory
}
base.Options = baseOptions
}
}
func mergeAgents(base, global, local *Config) {
for _, cfg := range []*Config{global, local} {
if cfg == nil {
continue
}
for agentID, newAgent := range cfg.Agents {
if _, ok := base.Agents[agentID]; !ok {
newAgent.ID = agentID // Ensure the ID is set correctly
base.Agents[agentID] = newAgent
} else {
switch agentID {
case AgentCoder:
baseAgent := base.Agents[agentID]
baseAgent.Model = newAgent.Model
baseAgent.Provider = newAgent.Provider
baseAgent.AllowedMCP = newAgent.AllowedMCP
baseAgent.AllowedLSP = newAgent.AllowedLSP
base.Agents[agentID] = baseAgent
case AgentTask:
baseAgent := base.Agents[agentID]
baseAgent.Model = newAgent.Model
baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
case AgentTitle:
baseAgent := base.Agents[agentID]
baseAgent.Model = newAgent.Model
baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
case AgentSummarize:
baseAgent := base.Agents[agentID]
baseAgent.Model = newAgent.Model
baseAgent.Provider = newAgent.Provider
base.Agents[agentID] = baseAgent
default:
baseAgent := base.Agents[agentID]
baseAgent.Name = newAgent.Name
baseAgent.Description = newAgent.Description
baseAgent.Disabled = newAgent.Disabled
baseAgent.Provider = newAgent.Provider
baseAgent.Model = newAgent.Model
baseAgent.AllowedTools = newAgent.AllowedTools
baseAgent.AllowedMCP = newAgent.AllowedMCP
baseAgent.AllowedLSP = newAgent.AllowedLSP
base.Agents[agentID] = baseAgent
}
}
}
}
}
func mergeMCPs(base, global, local *Config) {
for _, cfg := range []*Config{global, local} {
if cfg == nil {
continue
}
maps.Copy(base.MCP, cfg.MCP)
}
}
func mergeLSPs(base, global, local *Config) {
for _, cfg := range []*Config{global, local} {
if cfg == nil {
continue
}
maps.Copy(base.LSP, cfg.LSP)
}
}
func mergeProviderConfigs(base, global, local *Config) {
for _, cfg := range []*Config{global, local} {
if cfg == nil {
continue
}
for providerName, globalProvider := range cfg.Providers {
if _, ok := base.Providers[providerName]; !ok {
base.Providers[providerName] = globalProvider
} else {
base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
}
}
}
finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
for providerName, providerConfig := range base.Providers {
err := validateProvider(providerName, providerConfig)
if err != nil {
logging.Warn("Skipping provider", "name", providerName, "error", err)
}
finalProviders[providerName] = providerConfig
}
base.Providers = finalProviders
}
func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
switch providerId {
case provider.InferenceProviderAnthropic:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeAnthropic,
}
case provider.InferenceProviderOpenAI:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeOpenAI,
}
case provider.InferenceProviderGemini:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeGemini,
}
case provider.InferenceProviderBedrock:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeBedrock,
}
case provider.InferenceProviderAzure:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeAzure,
}
case provider.InferenceProviderOpenRouter:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeOpenAI,
BaseURL: "https://openrouter.ai/api/v1",
ExtraHeaders: map[string]string{
"HTTP-Referer": "crush.charm.land",
"X-Title": "Crush",
},
}
case provider.InferenceProviderXAI:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeXAI,
BaseURL: "https://api.x.ai/v1",
}
case provider.InferenceProviderVertexAI:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeVertexAI,
}
default:
return ProviderConfig{
ID: providerId,
ProviderType: provider.TypeOpenAI,
}
}
}
func defaultConfigBasedOnEnv() *Config {
cfg := &Config{
Options: Options{
DataDirectory: defaultDataDirectory,
ContextPaths: defaultContextPaths,
},
Providers: make(map[provider.InferenceProvider]ProviderConfig),
}
providers := Providers()
for _, p := range providers {
if strings.HasPrefix(p.APIKey, "$") {
envVar := strings.TrimPrefix(p.APIKey, "$")
if apiKey := os.Getenv(envVar); apiKey != "" {
providerConfig := providerDefaultConfig(p.ID)
providerConfig.APIKey = apiKey
providerConfig.DefaultLargeModel = p.DefaultLargeModelID
providerConfig.DefaultSmallModel = p.DefaultSmallModelID
for _, model := range p.Models {
providerConfig.Models = append(providerConfig.Models, Model{
ID: model.ID,
Name: model.Name,
CostPer1MIn: model.CostPer1MIn,
CostPer1MOut: model.CostPer1MOut,
CostPer1MInCached: model.CostPer1MInCached,
CostPer1MOutCached: model.CostPer1MOutCached,
ContextWindow: model.ContextWindow,
DefaultMaxTokens: model.DefaultMaxTokens,
CanReason: model.CanReason,
SupportsImages: model.SupportsImages,
})
}
cfg.Providers[p.ID] = providerConfig
}
}
}
// TODO: support local models
if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
providerConfig.ExtraParams = map[string]string{
"project": os.Getenv("GOOGLE_CLOUD_PROJECT"),
"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
}
cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
}
if hasAWSCredentials() {
providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
}
return cfg
}
func hasAWSCredentials() bool {
if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
return true
}
if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
return true
}
if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
return true
}
if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
return true
}
return false
}
func WorkingDirectory() string {
return cwd
}

View File

@@ -1,7 +1,6 @@
package db
import (
"context"
"database/sql"
"fmt"
"os"
@@ -16,8 +15,8 @@ import (
"github.com/pressly/goose/v3"
)
func Connect(ctx context.Context) (*sql.DB, error) {
dataDir := config.Get().Data.Directory
func Connect() (*sql.DB, error) {
dataDir := config.Get().Options.DataDirectory
if dataDir == "" {
return nil, fmt.Errorf("data.dir is not set")
}

View File

@@ -17,12 +17,13 @@ INSERT INTO messages (
role,
parts,
model,
provider,
created_at,
updated_at
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
`
type CreateMessageParams struct {
@@ -31,6 +32,7 @@ type CreateMessageParams struct {
Role string `json:"role"`
Parts string `json:"parts"`
Model sql.NullString `json:"model"`
Provider sql.NullString `json:"provider"`
}
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
@@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
arg.Role,
arg.Parts,
arg.Model,
arg.Provider,
)
var i Message
err := row.Scan(
@@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
&i.Provider,
)
return i, err
}
@@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
const getMessage = `-- name: GetMessage :one
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE id = ? LIMIT 1
`
@@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
&i.Provider,
)
return i, err
}
const listMessagesBySession = `-- name: ListMessagesBySession :many
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC
@@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
&i.CreatedAt,
&i.UpdatedAt,
&i.FinishedAt,
&i.Provider,
); err != nil {
return nil, err
}

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
-- Add provider column to messages table
ALTER TABLE messages ADD COLUMN provider TEXT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Remove provider column from messages table
ALTER TABLE messages DROP COLUMN provider;
-- +goose StatementEnd

View File

@@ -27,6 +27,7 @@ type Message struct {
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
FinishedAt sql.NullInt64 `json:"finished_at"`
Provider sql.NullString `json:"provider"`
}
type Session struct {

View File

@@ -16,10 +16,11 @@ INSERT INTO messages (
role,
parts,
model,
provider,
created_at,
updated_at
) VALUES (
?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;

View File

@@ -10,7 +10,7 @@ import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
const defaultURL = "http://localhost:8080"
const defaultURL = "https://fur.charmcli.dev"
// Client represents a client for the fur service.
type Client struct {

View File

@@ -6,14 +6,13 @@ type Type string
// All the supported AI provider types.
const (
TypeOpenAI Type = "openai"
TypeAnthropic Type = "anthropic"
TypeGemini Type = "gemini"
TypeAzure Type = "azure"
TypeBedrock Type = "bedrock"
TypeVertexAI Type = "vertexai"
TypeXAI Type = "xai"
TypeOpenRouter Type = "openrouter"
TypeOpenAI Type = "openai"
TypeAnthropic Type = "anthropic"
TypeGemini Type = "gemini"
TypeAzure Type = "azure"
TypeBedrock Type = "bedrock"
TypeVertexAI Type = "vertexai"
TypeXAI Type = "xai"
)
// InferenceProvider represents the inference provider identifier.

View File

@@ -5,17 +5,15 @@ import (
"encoding/json"
"fmt"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
)
type agentTool struct {
sessions session.Service
messages message.Service
lspClients map[string]*lsp.Client
agent Service
sessions session.Service
messages message.Service
}
const (
@@ -58,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
}
agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
}
session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
}
done, err := agent.Run(ctx, session.ID, params.Prompt)
done, err := b.agent.Run(ctx, session.ID, params.Prompt)
if err != nil {
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
}
@@ -101,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
}
func NewAgentTool(
Sessions session.Service,
Messages message.Service,
LspClients map[string]*lsp.Client,
agent Service,
sessions session.Service,
messages message.Service,
) tools.BaseTool {
return &agentTool{
sessions: Sessions,
messages: Messages,
lspClients: LspClients,
sessions: sessions,
messages: messages,
agent: agent,
}
}

View File

@@ -4,16 +4,18 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -47,71 +49,198 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
Model() models.Model
Model() configv2.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
IsBusy() bool
Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
Update(model configv2.PreferredModel) (configv2.Model, error)
Summarize(ctx context.Context, sessionID string) error
}
type agent struct {
*pubsub.Broker[AgentEvent]
agentCfg configv2.Agent
sessions session.Service
messages message.Service
tools []tools.BaseTool
provider provider.Provider
tools []tools.BaseTool
provider provider.Provider
providerID string
titleProvider provider.Provider
summarizeProvider provider.Provider
titleProvider provider.Provider
summarizeProvider provider.Provider
summarizeProviderID string
activeRequests sync.Map
}
var agentPromptMap = map[configv2.AgentID]prompt.PromptID{
configv2.AgentCoder: prompt.PromptCoder,
configv2.AgentTask: prompt.PromptTask,
}
func NewAgent(
agentName config.AgentName,
agentCfg configv2.Agent,
// These services are needed in the tools
permissions permission.Service,
sessions session.Service,
messages message.Service,
agentTools []tools.BaseTool,
history history.Service,
lspClients map[string]*lsp.Client,
) (Service, error) {
agentProvider, err := createAgentProvider(agentName)
ctx := context.Background()
cfg := configv2.Get()
otherTools := GetMcpTools(ctx, permissions)
if len(lspClients) > 0 {
otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
}
allTools := []tools.BaseTool{
tools.NewBashTool(permissions),
tools.NewEditTool(lspClients, permissions, history),
tools.NewFetchTool(permissions),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients),
tools.NewWriteTool(lspClients, permissions, history),
}
if agentCfg.ID == configv2.AgentCoder {
taskAgentCfg := configv2.Get().Agents[configv2.AgentTask]
if taskAgentCfg.ID == "" {
return nil, fmt.Errorf("task agent not found in config")
}
taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
if err != nil {
return nil, fmt.Errorf("failed to create task agent: %w", err)
}
allTools = append(
allTools,
NewAgentTool(
taskAgent,
sessions,
messages,
),
)
}
allTools = append(allTools, otherTools...)
var providerCfg configv2.ProviderConfig
for _, p := range cfg.Providers {
if p.ID == agentCfg.Provider {
providerCfg = p
break
}
}
if providerCfg.ID == "" {
return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider)
}
var model configv2.Model
for _, m := range providerCfg.Models {
if m.ID == agentCfg.Model {
model = m
break
}
}
if model.ID == "" {
return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider)
}
promptID := agentPromptMap[agentCfg.ID]
if promptID == "" {
promptID = prompt.PromptDefault
}
opts := []provider.ProviderClientOption{
provider.WithModel(model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
provider.WithMaxTokens(model.DefaultMaxTokens),
}
agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
if err != nil {
return nil, err
}
var titleProvider provider.Provider
// Only generate titles for the coder agent
if agentName == config.AgentCoder {
titleProvider, err = createAgentProvider(config.AgentTitle)
if err != nil {
return nil, err
smallModelCfg := cfg.Models.Small
var smallModel configv2.Model
var smallModelProviderCfg configv2.ProviderConfig
if smallModelCfg.Provider == providerCfg.ID {
smallModelProviderCfg = providerCfg
} else {
for _, p := range cfg.Providers {
if p.ID == smallModelCfg.Provider {
smallModelProviderCfg = p
break
}
}
if smallModelProviderCfg.ID == "" {
return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
}
var summarizeProvider provider.Provider
if agentName == config.AgentCoder {
summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
if err != nil {
return nil, err
for _, m := range smallModelProviderCfg.Models {
if m.ID == smallModelCfg.ModelID {
smallModel = m
break
}
}
if smallModel.ID == "" {
return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
}
titleOpts := []provider.ProviderClientOption{
provider.WithModel(smallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
provider.WithMaxTokens(40),
}
titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
if err != nil {
return nil, err
}
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(smallModel),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
provider.WithMaxTokens(smallModel.DefaultMaxTokens),
}
summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
if err != nil {
return nil, err
}
agentTools := []tools.BaseTool{}
if agentCfg.AllowedTools == nil {
agentTools = allTools
} else {
for _, tool := range allTools {
if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
agentTools = append(agentTools, tool)
}
}
}
agent := &agent{
Broker: pubsub.NewBroker[AgentEvent](),
provider: agentProvider,
messages: messages,
sessions: sessions,
tools: agentTools,
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
activeRequests: sync.Map{},
Broker: pubsub.NewBroker[AgentEvent](),
agentCfg: agentCfg,
provider: agentProvider,
providerID: string(providerCfg.ID),
messages: messages,
sessions: sessions,
tools: agentTools,
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
summarizeProviderID: string(smallModelProviderCfg.ID),
activeRequests: sync.Map{},
}
return agent, nil
}
func (a *agent) Model() models.Model {
func (a *agent) Model() configv2.Model {
return a.provider.Model()
}
@@ -207,7 +336,7 @@ func (a *agent) err(err error) AgentEvent {
}
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
if !a.provider.Model().SupportsAttachments && attachments != nil {
if !a.provider.Model().SupportsImages && attachments != nil {
attachments = nil
}
events := make(chan AgentEvent)
@@ -327,9 +456,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
Model: a.provider.Model().ID,
Role: message.Assistant,
Parts: []message.ContentPart{},
Model: a.provider.Model().ID,
Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
@@ -424,8 +554,9 @@ out:
parts = append(parts, tr)
}
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
Role: message.Tool,
Parts: parts,
Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
@@ -484,7 +615,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -506,21 +637,48 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
return nil
}
func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) {
if a.IsBusy() {
return models.Model{}, fmt.Errorf("cannot change model while processing requests")
return configv2.Model{}, fmt.Errorf("cannot change model while processing requests")
}
if err := config.UpdateAgentModel(agentName, modelID); err != nil {
return models.Model{}, fmt.Errorf("failed to update config: %w", err)
cfg := configv2.Get()
var providerCfg configv2.ProviderConfig
for _, p := range cfg.Providers {
if p.ID == modelCfg.Provider {
providerCfg = p
break
}
}
if providerCfg.ID == "" {
return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider)
}
provider, err := createAgentProvider(agentName)
var model configv2.Model
for _, m := range providerCfg.Models {
if m.ID == modelCfg.ModelID {
model = m
break
}
}
if model.ID == "" {
return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider)
}
promptID := agentPromptMap[a.agentCfg.ID]
if promptID == "" {
promptID = prompt.PromptDefault
}
opts := []provider.ProviderClientOption{
provider.WithModel(model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
provider.WithMaxTokens(model.DefaultMaxTokens),
}
agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
if err != nil {
return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
return configv2.Model{}, err
}
a.provider = provider
a.provider = agentProvider
return a.provider.Model(), nil
}
@@ -654,7 +812,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
Time: time.Now().Unix(),
},
},
Model: a.summarizeProvider.Model().ID,
Model: a.summarizeProvider.Model().ID,
Provider: a.summarizeProviderID,
})
if err != nil {
event = AgentEvent{
@@ -705,51 +864,3 @@ func (a *agent) CancelAll() {
return true
})
}
func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
cfg := config.Get()
agentConfig, ok := cfg.Agents[agentName]
if !ok {
return nil, fmt.Errorf("agent %s not found", agentName)
}
model, ok := models.SupportedModels[agentConfig.Model]
if !ok {
return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
}
providerCfg, ok := cfg.Providers[model.Provider]
if !ok {
return nil, fmt.Errorf("provider %s not supported", model.Provider)
}
if providerCfg.Disabled {
return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
}
maxTokens := model.DefaultMaxTokens
if agentConfig.MaxTokens > 0 {
maxTokens = agentConfig.MaxTokens
}
opts := []provider.ProviderClientOption{
provider.WithAPIKey(providerCfg.APIKey),
provider.WithModel(model),
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
provider.WithMaxTokens(maxTokens),
}
// TODO: reimplement
// if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
// opts = append(
// opts,
// provider.WithOpenAIOptions(
// provider.WithReasoningEffort(agentConfig.ReasoningEffort),
// ),
// )
// }
agentProvider, err := provider.NewProvider(
model.Provider,
opts...,
)
if err != nil {
return nil, fmt.Errorf("could not create provider: %v", err)
}
return agentProvider, nil
}

View File

@@ -18,7 +18,7 @@ import (
type mcpTool struct {
mcpName string
tool mcp.Tool
mcpConfig config.MCPServer
mcpConfig config.MCP
permissions permission.Service
}
@@ -128,7 +128,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
@@ -139,7 +139,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
var mcpTools []tools.BaseTool
func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
@@ -170,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
if len(mcpTools) > 0 {
return mcpTools
}
for name, m := range config.Get().MCPServers {
for name, m := range config.Get().MCP {
switch m.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(

View File

@@ -1,50 +0,0 @@
package agent
import (
"context"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/session"
)
func CoderAgentTools(
permissions permission.Service,
sessions session.Service,
messages message.Service,
history history.Service,
lspClients map[string]*lsp.Client,
) []tools.BaseTool {
ctx := context.Background()
otherTools := GetMcpTools(ctx, permissions)
if len(lspClients) > 0 {
otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
}
return append(
[]tools.BaseTool{
tools.NewBashTool(permissions),
tools.NewEditTool(lspClients, permissions, history),
tools.NewFetchTool(permissions),
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients),
tools.NewWriteTool(lspClients, permissions, history),
NewAgentTool(sessions, messages, lspClients),
}, otherTools...,
)
}
func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
return []tools.BaseTool{
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients),
}
}

View File

@@ -1,111 +0,0 @@
package models
const (
ProviderAnthropic InferenceProvider = "anthropic"
// Models
Claude35Sonnet ModelID = "claude-3.5-sonnet"
Claude3Haiku ModelID = "claude-3-haiku"
Claude37Sonnet ModelID = "claude-3.7-sonnet"
Claude35Haiku ModelID = "claude-3.5-haiku"
Claude3Opus ModelID = "claude-3-opus"
Claude4Opus ModelID = "claude-4-opus"
Claude4Sonnet ModelID = "claude-4-sonnet"
)
// https://docs.anthropic.com/en/docs/about-claude/models/all-models
var AnthropicModels = map[ModelID]Model{
Claude35Sonnet: {
ID: Claude35Sonnet,
Name: "Claude 3.5 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 5000,
SupportsAttachments: true,
},
Claude3Haiku: {
ID: Claude3Haiku,
Name: "Claude 3 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-haiku-20240307", // doesn't support "-latest"
CostPer1MIn: 0.25,
CostPer1MInCached: 0.30,
CostPer1MOutCached: 0.03,
CostPer1MOut: 1.25,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
Claude37Sonnet: {
ID: Claude37Sonnet,
Name: "Claude 3.7 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-3-7-sonnet-latest",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
Claude35Haiku: {
ID: Claude35Haiku,
Name: "Claude 3.5 Haiku",
Provider: ProviderAnthropic,
APIModel: "claude-3-5-haiku-latest",
CostPer1MIn: 0.80,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
CostPer1MOut: 4.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
Claude3Opus: {
ID: Claude3Opus,
Name: "Claude 3 Opus",
Provider: ProviderAnthropic,
APIModel: "claude-3-opus-latest",
CostPer1MIn: 15.0,
CostPer1MInCached: 18.75,
CostPer1MOutCached: 1.50,
CostPer1MOut: 75.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
Claude4Sonnet: {
ID: Claude4Sonnet,
Name: "Claude 4 Sonnet",
Provider: ProviderAnthropic,
APIModel: "claude-sonnet-4-20250514",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
Claude4Opus: {
ID: Claude4Opus,
Name: "Claude 4 Opus",
Provider: ProviderAnthropic,
APIModel: "claude-opus-4-20250514",
CostPer1MIn: 15.0,
CostPer1MInCached: 18.75,
CostPer1MOutCached: 1.50,
CostPer1MOut: 75.0,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
}

View File

@@ -1,168 +0,0 @@
package models
const ProviderAzure InferenceProvider = "azure"
const (
AzureGPT41 ModelID = "azure.gpt-4.1"
AzureGPT41Mini ModelID = "azure.gpt-4.1-mini"
AzureGPT41Nano ModelID = "azure.gpt-4.1-nano"
AzureGPT45Preview ModelID = "azure.gpt-4.5-preview"
AzureGPT4o ModelID = "azure.gpt-4o"
AzureGPT4oMini ModelID = "azure.gpt-4o-mini"
AzureO1 ModelID = "azure.o1"
AzureO1Mini ModelID = "azure.o1-mini"
AzureO3 ModelID = "azure.o3"
AzureO3Mini ModelID = "azure.o3-mini"
AzureO4Mini ModelID = "azure.o4-mini"
)
var AzureModels = map[ModelID]Model{
AzureGPT41: {
ID: AzureGPT41,
Name: "Azure OpenAI GPT 4.1",
Provider: ProviderAzure,
APIModel: "gpt-4.1",
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT41Mini: {
ID: AzureGPT41Mini,
Name: "Azure OpenAI GPT 4.1 mini",
Provider: ProviderAzure,
APIModel: "gpt-4.1-mini",
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT41Nano: {
ID: AzureGPT41Nano,
Name: "Azure OpenAI GPT 4.1 nano",
Provider: ProviderAzure,
APIModel: "gpt-4.1-nano",
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT45Preview: {
ID: AzureGPT45Preview,
Name: "Azure OpenAI GPT 4.5 preview",
Provider: ProviderAzure,
APIModel: "gpt-4.5-preview",
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT4o: {
ID: AzureGPT4o,
Name: "Azure OpenAI GPT-4o",
Provider: ProviderAzure,
APIModel: "gpt-4o",
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureGPT4oMini: {
ID: AzureGPT4oMini,
Name: "Azure OpenAI GPT-4o mini",
Provider: ProviderAzure,
APIModel: "gpt-4o-mini",
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
SupportsAttachments: true,
},
AzureO1: {
ID: AzureO1,
Name: "Azure OpenAI O1",
Provider: ProviderAzure,
APIModel: "o1",
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
CanReason: OpenAIModels[O1].CanReason,
SupportsAttachments: true,
},
AzureO1Mini: {
ID: AzureO1Mini,
Name: "Azure OpenAI O1 mini",
Provider: ProviderAzure,
APIModel: "o1-mini",
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O1Mini].CanReason,
SupportsAttachments: true,
},
AzureO3: {
ID: AzureO3,
Name: "Azure OpenAI O3",
Provider: ProviderAzure,
APIModel: "o3",
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
CanReason: OpenAIModels[O3].CanReason,
SupportsAttachments: true,
},
AzureO3Mini: {
ID: AzureO3Mini,
Name: "Azure OpenAI O3 mini",
Provider: ProviderAzure,
APIModel: "o3-mini",
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O3Mini].CanReason,
SupportsAttachments: false,
},
AzureO4Mini: {
ID: AzureO4Mini,
Name: "Azure OpenAI O4 mini",
Provider: ProviderAzure,
APIModel: "o4-mini",
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O4Mini].CanReason,
SupportsAttachments: true,
},
}

View File

@@ -1,67 +0,0 @@
package models
const (
ProviderGemini InferenceProvider = "gemini"
// Models
Gemini25Flash ModelID = "gemini-2.5-flash"
Gemini25 ModelID = "gemini-2.5"
Gemini20Flash ModelID = "gemini-2.0-flash"
Gemini20FlashLite ModelID = "gemini-2.0-flash-lite"
)
var GeminiModels = map[ModelID]Model{
Gemini25Flash: {
ID: Gemini25Flash,
Name: "Gemini 2.5 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.5-flash-preview-04-17",
CostPer1MIn: 0.15,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.60,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
SupportsAttachments: true,
},
Gemini25: {
ID: Gemini25,
Name: "Gemini 2.5 Pro",
Provider: ProviderGemini,
APIModel: "gemini-2.5-pro-preview-05-06",
CostPer1MIn: 1.25,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 10,
ContextWindow: 1000000,
DefaultMaxTokens: 50000,
SupportsAttachments: true,
},
Gemini20Flash: {
ID: Gemini20Flash,
Name: "Gemini 2.0 Flash",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash",
CostPer1MIn: 0.10,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.40,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
SupportsAttachments: true,
},
Gemini20FlashLite: {
ID: Gemini20FlashLite,
Name: "Gemini 2.0 Flash Lite",
Provider: ProviderGemini,
APIModel: "gemini-2.0-flash-lite",
CostPer1MIn: 0.05,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.30,
ContextWindow: 1000000,
DefaultMaxTokens: 6000,
SupportsAttachments: true,
},
}

View File

@@ -1,87 +0,0 @@
package models
const (
ProviderGROQ InferenceProvider = "groq"
// GROQ
QWENQwq ModelID = "qwen-qwq"
// GROQ preview models
Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct"
Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct"
Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile"
DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b"
)
var GroqModels = map[ModelID]Model{
//
// GROQ
QWENQwq: {
ID: QWENQwq,
Name: "Qwen Qwq",
Provider: ProviderGROQ,
APIModel: "qwen-qwq-32b",
CostPer1MIn: 0.29,
CostPer1MInCached: 0.275,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.39,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
// for some reason, the groq api doesn't like the reasoningEffort parameter
CanReason: false,
SupportsAttachments: false,
},
Llama4Scout: {
ID: Llama4Scout,
Name: "Llama4Scout",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
CostPer1MIn: 0.11,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.34,
ContextWindow: 128_000, // 10M when?
SupportsAttachments: true,
},
Llama4Maverick: {
ID: Llama4Maverick,
Name: "Llama4Maverick",
Provider: ProviderGROQ,
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
CostPer1MIn: 0.20,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.20,
ContextWindow: 128_000,
SupportsAttachments: true,
},
Llama3_3_70BVersatile: {
ID: Llama3_3_70BVersatile,
Name: "Llama3_3_70BVersatile",
Provider: ProviderGROQ,
APIModel: "llama-3.3-70b-versatile",
CostPer1MIn: 0.59,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.79,
ContextWindow: 128_000,
SupportsAttachments: false,
},
DeepseekR1DistillLlama70b: {
ID: DeepseekR1DistillLlama70b,
Name: "DeepseekR1DistillLlama70b",
Provider: ProviderGROQ,
APIModel: "deepseek-r1-distill-llama-70b",
CostPer1MIn: 0.75,
CostPer1MInCached: 0,
CostPer1MOutCached: 0,
CostPer1MOut: 0.99,
ContextWindow: 128_000,
CanReason: true,
SupportsAttachments: false,
},
}

View File

@@ -1,206 +0,0 @@
package models
import (
"cmp"
"context"
"encoding/json"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"unicode"
"github.com/charmbracelet/crush/internal/logging"
"github.com/spf13/viper"
)
const (
ProviderLocal InferenceProvider = "local"
localModelsPath = "v1/models"
lmStudioBetaModelsPath = "api/v0/models"
)
func init() {
if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
localEndpoint, err := url.Parse(endpoint)
if err != nil {
logging.Debug("Failed to parse local endpoint",
"error", err,
"endpoint", endpoint,
)
return
}
load := func(url *url.URL, path string) []localModel {
url.Path = path
return listLocalModels(url.String())
}
models := load(localEndpoint, lmStudioBetaModelsPath)
if len(models) == 0 {
models = load(localEndpoint, localModelsPath)
}
if len(models) == 0 {
logging.Debug("No local models found",
"endpoint", endpoint,
)
return
}
loadLocalModels(models)
viper.SetDefault("providers.local.apiKey", "dummy")
}
}
type localModelList struct {
Data []localModel `json:"data"`
}
type localModel struct {
ID string `json:"id"`
Object string `json:"object"`
Type string `json:"type"`
Publisher string `json:"publisher"`
Arch string `json:"arch"`
CompatibilityType string `json:"compatibility_type"`
Quantization string `json:"quantization"`
State string `json:"state"`
MaxContextLength int64 `json:"max_context_length"`
LoadedContextLength int64 `json:"loaded_context_length"`
}
func listLocalModels(modelsEndpoint string) []localModel {
res, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil)
if err != nil {
logging.Debug("Failed to list local models",
"error", err,
"endpoint", modelsEndpoint,
)
}
defer res.Body.Close()
if res.Response.StatusCode != http.StatusOK {
logging.Debug("Failed to list local models",
"status", res.Response.Status,
"endpoint", modelsEndpoint,
)
}
var modelList localModelList
if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
logging.Debug("Failed to list local models",
"error", err,
"endpoint", modelsEndpoint,
)
}
var supportedModels []localModel
for _, model := range modelList.Data {
if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
if model.Object != "model" || model.Type != "llm" {
logging.Debug("Skipping unsupported LMStudio model",
"endpoint", modelsEndpoint,
"id", model.ID,
"object", model.Object,
"type", model.Type,
)
continue
}
}
supportedModels = append(supportedModels, model)
}
return supportedModels
}
func loadLocalModels(models []localModel) {
for i, m := range models {
model := convertLocalModel(m)
SupportedModels[model.ID] = model
if i == 0 || m.State == "loaded" {
viper.SetDefault("agents.coder.model", model.ID)
viper.SetDefault("agents.summarizer.model", model.ID)
viper.SetDefault("agents.task.model", model.ID)
viper.SetDefault("agents.title.model", model.ID)
}
}
}
func convertLocalModel(model localModel) Model {
return Model{
ID: ModelID("local." + model.ID),
Name: friendlyModelName(model.ID),
Provider: ProviderLocal,
APIModel: model.ID,
ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
CanReason: true,
SupportsAttachments: true,
}
}
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
func friendlyModelName(modelID string) string {
mainID := modelID
tag := ""
if slash := strings.LastIndex(mainID, "/"); slash != -1 {
mainID = mainID[slash+1:]
}
if at := strings.Index(modelID, "@"); at != -1 {
mainID = modelID[:at]
tag = modelID[at+1:]
}
match := modelInfoRegex.FindStringSubmatch(mainID)
if match == nil {
return modelID
}
capitalize := func(s string) string {
if s == "" {
return ""
}
runes := []rune(s)
runes[0] = unicode.ToUpper(runes[0])
return string(runes)
}
family := capitalize(match[1])
version := ""
label := ""
if len(match) > 2 && match[2] != "" {
version = strings.ToUpper(match[2])
}
if len(match) > 3 && match[3] != "" {
label = capitalize(match[3])
}
var parts []string
if family != "" {
parts = append(parts, family)
}
if version != "" {
parts = append(parts, version)
}
if label != "" {
parts = append(parts, label)
}
if tag != "" {
parts = append(parts, tag)
}
return strings.Join(parts, " ")
}

View File

@@ -1,74 +0,0 @@
package models
import "maps"
type (
ModelID string
InferenceProvider string
)
type Model struct {
ID ModelID `json:"id"`
Name string `json:"name"`
Provider InferenceProvider `json:"provider"`
APIModel string `json:"api_model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow int64 `json:"context_window"`
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
SupportsAttachments bool `json:"supports_attachments"`
}
// Model IDs
const ( // GEMINI
// Bedrock
BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
)
const (
ProviderBedrock InferenceProvider = "bedrock"
// ForTests
ProviderMock InferenceProvider = "__mock"
)
var SupportedModels = map[ModelID]Model{
// Bedrock
BedrockClaude37Sonnet: {
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
}
var KnownProviders = []InferenceProvider{
ProviderAnthropic,
ProviderOpenAI,
ProviderGemini,
ProviderAzure,
ProviderGROQ,
ProviderLocal,
ProviderOpenRouter,
ProviderVertexAI,
ProviderBedrock,
ProviderXAI,
ProviderMock,
}
func init() {
maps.Copy(SupportedModels, AnthropicModels)
maps.Copy(SupportedModels, OpenAIModels)
maps.Copy(SupportedModels, GeminiModels)
maps.Copy(SupportedModels, GroqModels)
maps.Copy(SupportedModels, AzureModels)
maps.Copy(SupportedModels, OpenRouterModels)
maps.Copy(SupportedModels, XAIModels)
maps.Copy(SupportedModels, VertexAIGeminiModels)
}

View File

@@ -1,181 +0,0 @@
package models
const (
ProviderOpenAI InferenceProvider = "openai"
GPT41 ModelID = "gpt-4.1"
GPT41Mini ModelID = "gpt-4.1-mini"
GPT41Nano ModelID = "gpt-4.1-nano"
GPT45Preview ModelID = "gpt-4.5-preview"
GPT4o ModelID = "gpt-4o"
GPT4oMini ModelID = "gpt-4o-mini"
O1 ModelID = "o1"
O1Pro ModelID = "o1-pro"
O1Mini ModelID = "o1-mini"
O3 ModelID = "o3"
O3Mini ModelID = "o3-mini"
O4Mini ModelID = "o4-mini"
)
var OpenAIModels = map[ModelID]Model{
GPT41: {
ID: GPT41,
Name: "GPT 4.1",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 2.00,
CostPer1MInCached: 0.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 8.00,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT41Mini: {
ID: GPT41Mini,
Name: "GPT 4.1 mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1",
CostPer1MIn: 0.40,
CostPer1MInCached: 0.10,
CostPer1MOutCached: 0.0,
CostPer1MOut: 1.60,
ContextWindow: 200_000,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT41Nano: {
ID: GPT41Nano,
Name: "GPT 4.1 nano",
Provider: ProviderOpenAI,
APIModel: "gpt-4.1-nano",
CostPer1MIn: 0.10,
CostPer1MInCached: 0.025,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.40,
ContextWindow: 1_047_576,
DefaultMaxTokens: 20000,
SupportsAttachments: true,
},
GPT45Preview: {
ID: GPT45Preview,
Name: "GPT 4.5 preview",
Provider: ProviderOpenAI,
APIModel: "gpt-4.5-preview",
CostPer1MIn: 75.00,
CostPer1MInCached: 37.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 150.00,
ContextWindow: 128_000,
DefaultMaxTokens: 15000,
SupportsAttachments: true,
},
GPT4o: {
ID: GPT4o,
Name: "GPT 4o",
Provider: ProviderOpenAI,
APIModel: "gpt-4o",
CostPer1MIn: 2.50,
CostPer1MInCached: 1.25,
CostPer1MOutCached: 0.0,
CostPer1MOut: 10.00,
ContextWindow: 128_000,
DefaultMaxTokens: 4096,
SupportsAttachments: true,
},
GPT4oMini: {
ID: GPT4oMini,
Name: "GPT 4o mini",
Provider: ProviderOpenAI,
APIModel: "gpt-4o-mini",
CostPer1MIn: 0.15,
CostPer1MInCached: 0.075,
CostPer1MOutCached: 0.0,
CostPer1MOut: 0.60,
ContextWindow: 128_000,
SupportsAttachments: true,
},
O1: {
ID: O1,
Name: "O1",
Provider: ProviderOpenAI,
APIModel: "o1",
CostPer1MIn: 15.00,
CostPer1MInCached: 7.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 60.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O1Pro: {
ID: O1Pro,
Name: "o1 pro",
Provider: ProviderOpenAI,
APIModel: "o1-pro",
CostPer1MIn: 150.00,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.0,
CostPer1MOut: 600.00,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O1Mini: {
ID: O1Mini,
Name: "o1 mini",
Provider: ProviderOpenAI,
APIModel: "o1-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
O3: {
ID: O3,
Name: "o3",
Provider: ProviderOpenAI,
APIModel: "o3",
CostPer1MIn: 10.00,
CostPer1MInCached: 2.50,
CostPer1MOutCached: 0.0,
CostPer1MOut: 40.00,
ContextWindow: 200_000,
CanReason: true,
SupportsAttachments: true,
},
O3Mini: {
ID: O3Mini,
Name: "o3 mini",
Provider: ProviderOpenAI,
APIModel: "o3-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.55,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 200_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: false,
},
O4Mini: {
ID: O4Mini,
Name: "o4 mini",
Provider: ProviderOpenAI,
APIModel: "o4-mini",
CostPer1MIn: 1.10,
CostPer1MInCached: 0.275,
CostPer1MOutCached: 0.0,
CostPer1MOut: 4.40,
ContextWindow: 128_000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsAttachments: true,
},
}

View File

@@ -1,276 +0,0 @@
package models
const (
ProviderOpenRouter InferenceProvider = "openrouter"
OpenRouterGPT41 ModelID = "openrouter.gpt-4.1"
OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini"
OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano"
OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview"
OpenRouterGPT4o ModelID = "openrouter.gpt-4o"
OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini"
OpenRouterO1 ModelID = "openrouter.o1"
OpenRouterO1Pro ModelID = "openrouter.o1-pro"
OpenRouterO1Mini ModelID = "openrouter.o1-mini"
OpenRouterO3 ModelID = "openrouter.o3"
OpenRouterO3Mini ModelID = "openrouter.o3-mini"
OpenRouterO4Mini ModelID = "openrouter.o4-mini"
OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash"
OpenRouterGemini25 ModelID = "openrouter.gemini-2.5"
OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet"
OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku"
OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free"
)
var OpenRouterModels = map[ModelID]Model{
OpenRouterGPT41: {
ID: OpenRouterGPT41,
Name: "OpenRouter GPT 4.1",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1",
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
},
OpenRouterGPT41Mini: {
ID: OpenRouterGPT41Mini,
Name: "OpenRouter GPT 4.1 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1-mini",
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
},
OpenRouterGPT41Nano: {
ID: OpenRouterGPT41Nano,
Name: "OpenRouter GPT 4.1 nano",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.1-nano",
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
},
OpenRouterGPT45Preview: {
ID: OpenRouterGPT45Preview,
Name: "OpenRouter GPT 4.5 preview",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4.5-preview",
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
},
OpenRouterGPT4o: {
ID: OpenRouterGPT4o,
Name: "OpenRouter GPT 4o",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4o",
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
},
OpenRouterGPT4oMini: {
ID: OpenRouterGPT4oMini,
Name: "OpenRouter GPT 4o mini",
Provider: ProviderOpenRouter,
APIModel: "openai/gpt-4o-mini",
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
},
OpenRouterO1: {
ID: OpenRouterO1,
Name: "OpenRouter O1",
Provider: ProviderOpenRouter,
APIModel: "openai/o1",
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
CanReason: OpenAIModels[O1].CanReason,
},
OpenRouterO1Pro: {
ID: OpenRouterO1Pro,
Name: "OpenRouter o1 pro",
Provider: ProviderOpenRouter,
APIModel: "openai/o1-pro",
CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1Pro].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens,
CanReason: OpenAIModels[O1Pro].CanReason,
},
OpenRouterO1Mini: {
ID: OpenRouterO1Mini,
Name: "OpenRouter o1 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o1-mini",
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O1Mini].CanReason,
},
OpenRouterO3: {
ID: OpenRouterO3,
Name: "OpenRouter o3",
Provider: ProviderOpenRouter,
APIModel: "openai/o3",
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
CanReason: OpenAIModels[O3].CanReason,
},
OpenRouterO3Mini: {
ID: OpenRouterO3Mini,
Name: "OpenRouter o3 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o3-mini-high",
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O3Mini].CanReason,
},
OpenRouterO4Mini: {
ID: OpenRouterO4Mini,
Name: "OpenRouter o4 mini",
Provider: ProviderOpenRouter,
APIModel: "openai/o4-mini-high",
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
CanReason: OpenAIModels[O4Mini].CanReason,
},
OpenRouterGemini25Flash: {
ID: OpenRouterGemini25Flash,
Name: "OpenRouter Gemini 2.5 Flash",
Provider: ProviderOpenRouter,
APIModel: "google/gemini-2.5-flash-preview:thinking",
CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
},
OpenRouterGemini25: {
ID: OpenRouterGemini25,
Name: "OpenRouter Gemini 2.5 Pro",
Provider: ProviderOpenRouter,
APIModel: "google/gemini-2.5-pro-preview-03-25",
CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
ContextWindow: GeminiModels[Gemini25].ContextWindow,
DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
},
OpenRouterClaude35Sonnet: {
ID: OpenRouterClaude35Sonnet,
Name: "OpenRouter Claude 3.5 Sonnet",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.5-sonnet",
CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached,
CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut,
CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached,
ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens,
},
OpenRouterClaude3Haiku: {
ID: OpenRouterClaude3Haiku,
Name: "OpenRouter Claude 3 Haiku",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3-haiku",
CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached,
CostPer1MOut: AnthropicModels[Claude3Haiku].CostPer1MOut,
CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached,
ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens,
},
OpenRouterClaude37Sonnet: {
ID: OpenRouterClaude37Sonnet,
Name: "OpenRouter Claude 3.7 Sonnet",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.7-sonnet",
CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached,
CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut,
CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached,
ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens,
CanReason: AnthropicModels[Claude37Sonnet].CanReason,
},
OpenRouterClaude35Haiku: {
ID: OpenRouterClaude35Haiku,
Name: "OpenRouter Claude 3.5 Haiku",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3.5-haiku",
CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached,
CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut,
CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached,
ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens,
},
OpenRouterClaude3Opus: {
ID: OpenRouterClaude3Opus,
Name: "OpenRouter Claude 3 Opus",
Provider: ProviderOpenRouter,
APIModel: "anthropic/claude-3-opus",
CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached,
CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut,
CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached,
ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
},
OpenRouterDeepSeekR1Free: {
ID: OpenRouterDeepSeekR1Free,
Name: "OpenRouter DeepSeek R1 Free",
Provider: ProviderOpenRouter,
APIModel: "deepseek/deepseek-r1-0528:free",
CostPer1MIn: 0,
CostPer1MInCached: 0,
CostPer1MOut: 0,
CostPer1MOutCached: 0,
ContextWindow: 163_840,
DefaultMaxTokens: 10000,
},
}

View File

@@ -1,38 +0,0 @@
package models
const (
ProviderVertexAI InferenceProvider = "vertexai"
// Models
VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
VertexAIGemini25 ModelID = "vertexai.gemini-2.5"
)
var VertexAIGeminiModels = map[ModelID]Model{
VertexAIGemini25Flash: {
ID: VertexAIGemini25Flash,
Name: "VertexAI: Gemini 2.5 Flash",
Provider: ProviderVertexAI,
APIModel: "gemini-2.5-flash-preview-04-17",
CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
SupportsAttachments: true,
},
VertexAIGemini25: {
ID: VertexAIGemini25,
Name: "VertexAI: Gemini 2.5 Pro",
Provider: ProviderVertexAI,
APIModel: "gemini-2.5-pro-preview-03-25",
CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
ContextWindow: GeminiModels[Gemini25].ContextWindow,
DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
SupportsAttachments: true,
},
}

View File

@@ -1,61 +0,0 @@
package models
const (
ProviderXAI InferenceProvider = "xai"
XAIGrok3Beta ModelID = "grok-3-beta"
XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
)
var XAIModels = map[ModelID]Model{
XAIGrok3Beta: {
ID: XAIGrok3Beta,
Name: "Grok3 Beta",
Provider: ProviderXAI,
APIModel: "grok-3-beta",
CostPer1MIn: 3.0,
CostPer1MInCached: 0,
CostPer1MOut: 15,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAIGrok3MiniBeta: {
ID: XAIGrok3MiniBeta,
Name: "Grok3 Mini Beta",
Provider: ProviderXAI,
APIModel: "grok-3-mini-beta",
CostPer1MIn: 0.3,
CostPer1MInCached: 0,
CostPer1MOut: 0.5,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAIGrok3FastBeta: {
ID: XAIGrok3FastBeta,
Name: "Grok3 Fast Beta",
Provider: ProviderXAI,
APIModel: "grok-3-fast-beta",
CostPer1MIn: 5,
CostPer1MInCached: 0,
CostPer1MOut: 25,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
XAiGrok3MiniFastBeta: {
ID: XAiGrok3MiniFastBeta,
Name: "Grok3 Mini Fast Beta",
Provider: ProviderXAI,
APIModel: "grok-3-mini-fast-beta",
CostPer1MIn: 0.6,
CostPer1MInCached: 0,
CostPer1MOut: 4.0,
CostPer1MOutCached: 0,
ContextWindow: 131_072,
DefaultMaxTokens: 20_000,
},
}

View File

@@ -9,19 +9,27 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
)
func CoderPrompt(provider models.InferenceProvider) string {
func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
basePrompt := baseAnthropicCoderPrompt
switch provider {
case models.ProviderOpenAI:
switch p {
case provider.InferenceProviderOpenAI:
basePrompt = baseOpenAICoderPrompt
}
envInfo := getEnvironmentInfo()
return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
contextContent := getContextFromPaths(contextFiles)
logging.Debug("Context content", "Context", contextContent)
if contextContent != "" {
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
}
return basePrompt
}
const baseOpenAICoderPrompt = `

View File

@@ -1,60 +1,44 @@
package prompt
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/fur/provider"
)
func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string {
type PromptID string
const (
PromptCoder PromptID = "coder"
PromptTitle PromptID = "title"
PromptTask PromptID = "task"
PromptSummarizer PromptID = "summarizer"
PromptDefault PromptID = "default"
)
func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
basePrompt := ""
switch agentName {
case config.AgentCoder:
switch promptID {
case PromptCoder:
basePrompt = CoderPrompt(provider)
case config.AgentTitle:
case PromptTitle:
basePrompt = TitlePrompt(provider)
case config.AgentTask:
case PromptTask:
basePrompt = TaskPrompt(provider)
case config.AgentSummarizer:
case PromptSummarizer:
basePrompt = SummarizerPrompt(provider)
default:
basePrompt = "You are a helpful assistant"
}
if agentName == config.AgentCoder || agentName == config.AgentTask {
// Add context from project-specific instruction files if they exist
contextContent := getContextFromPaths()
logging.Debug("Context content", "Context", contextContent)
if contextContent != "" {
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
}
}
return basePrompt
}
var (
onceContext sync.Once
contextContent string
)
func getContextFromPaths() string {
onceContext.Do(func() {
var (
cfg = config.Get()
workDir = cfg.WorkingDir
contextPaths = cfg.ContextPaths
)
contextContent = processContextPaths(workDir, contextPaths)
})
return contextContent
func getContextFromPaths(contextPaths []string) string {
return processContextPaths(config.WorkingDirectory(), contextPaths)
}
func processContextPaths(workDir string, paths []string) string {

View File

@@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
_, err := config.Load(tmpDir, false)
_, err := config.Init(tmpDir, false)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg := config.Get()
cfg.WorkingDir = tmpDir
cfg.ContextPaths = []string{
"file.txt",
"directory/",
}
testFiles := []string{
"file.txt",
"directory/file_a.txt",
@@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) {
createTestFiles(t, tmpDir, testFiles)
context := getContextFromPaths()
context := getContextFromPaths(
[]string{
"file.txt",
"directory/",
},
)
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
assert.Equal(t, expectedContext, context)
}

View File

@@ -1,8 +1,10 @@
package prompt
import "github.com/charmbracelet/crush/internal/llm/models"
import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
func SummarizerPrompt(_ models.InferenceProvider) string {
func SummarizerPrompt(_ provider.InferenceProvider) string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.

View File

@@ -3,10 +3,10 @@ package prompt
import (
"fmt"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/fur/provider"
)
func TaskPrompt(_ models.InferenceProvider) string {
func TaskPrompt(_ provider.InferenceProvider) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".

View File

@@ -1,8 +1,10 @@
package prompt
import "github.com/charmbracelet/crush/internal/llm/models"
import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
func TitlePrompt(_ models.InferenceProvider) string {
func TitlePrompt(_ provider.InferenceProvider) string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message

View File

@@ -13,7 +13,7 @@ import (
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -59,7 +59,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
var contentBlocks []anthropic.ContentBlockParamUnion
contentBlocks = append(contentBlocks, content)
for _, binaryContent := range msg.BinaryContent() {
base64Image := binaryContent.String(models.ProviderAnthropic)
base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
contentBlocks = append(contentBlocks, imageBlock)
}
@@ -164,7 +164,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
// }
return anthropic.MessageNewParams{
Model: anthropic.Model(a.providerOptions.model.APIModel),
Model: anthropic.Model(a.providerOptions.model.ID),
MaxTokens: a.providerOptions.maxTokens,
Temperature: temperature,
Messages: messages,
@@ -184,7 +184,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
jsonData, _ := json.Marshal(preparedMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -233,7 +233,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
// jsonData, _ := json.Marshal(preparedMessages)
// logging.Debug("Prepared messages", "messages", string(jsonData))
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/charmbracelet/crush/internal/llm/tools"
@@ -19,14 +18,8 @@ type bedrockClient struct {
type BedrockClient ProviderClient
func newBedrockClient(opts providerClientOptions) BedrockClient {
// Apply bedrock specific options if they are added in the future
// Get AWS region from environment
region := os.Getenv("AWS_REGION")
if region == "" {
region = os.Getenv("AWS_DEFAULT_REGION")
}
region := opts.extraParams["region"]
if region == "" {
region = "us-east-1" // default region
}
@@ -39,11 +32,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
// Prefix the model name with region
regionPrefix := region[:2]
modelName := opts.model.APIModel
opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
modelName := opts.model.ID
opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
// Determine which provider to use based on the model
if strings.Contains(string(opts.model.APIModel), "anthropic") {
if strings.Contains(string(opts.model.ID), "anthropic") {
// Create Anthropic client with Bedrock configuration
anthropicOpts := opts
// TODO: later find a way to check if the AWS account has caching enabled

View File

@@ -157,7 +157,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -173,7 +173,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)
attempts := 0
for {
@@ -245,7 +245,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
geminiMessages := g.convertMessages(messages)
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
jsonData, _ := json.Marshal(geminiMessages)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -261,7 +261,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
if len(tools) > 0 {
config.Tools = g.convertTools(tools)
}
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)
attempts := 0
eventChan := make(chan ProviderEvent)

View File

@@ -9,7 +9,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -68,7 +68,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -153,7 +153,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(o.providerOptions.model.APIModel),
Model: openai.ChatModel(o.providerOptions.model.ID),
Messages: messages,
Tools: tools,
}
@@ -180,7 +180,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}
@@ -237,7 +237,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
}
cfg := config.Get()
if cfg.Debug {
if cfg.Options.Debug {
jsonData, _ := json.Marshal(params)
logging.Debug("Prepared messages", "messages", string(jsonData))
}

View File

@@ -3,9 +3,9 @@ package provider
import (
"context"
"fmt"
"os"
"github.com/charmbracelet/crush/internal/llm/models"
configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -55,17 +55,18 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() models.Model
Model() configv2.Model
}
type providerClientOptions struct {
baseURL string
apiKey string
model models.Model
model configv2.Model
disableCache bool
maxTokens int64
systemMessage string
extraHeaders map[string]string
extraParams map[string]string
}
type ProviderClientOption func(*providerClientOptions)
@@ -80,77 +81,6 @@ type baseProvider[C ProviderClient] struct {
client C
}
func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
clientOptions := providerClientOptions{}
for _, o := range opts {
o(&clientOptions)
}
switch providerName {
case models.ProviderAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
client: newAnthropicClient(clientOptions, false),
}, nil
case models.ProviderOpenAI:
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderGemini:
return &baseProvider[GeminiClient]{
options: clientOptions,
client: newGeminiClient(clientOptions),
}, nil
case models.ProviderBedrock:
return &baseProvider[BedrockClient]{
options: clientOptions,
client: newBedrockClient(clientOptions),
}, nil
case models.ProviderGROQ:
clientOptions.baseURL = "https://api.groq.com/openai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderAzure:
return &baseProvider[AzureClient]{
options: clientOptions,
client: newAzureClient(clientOptions),
}, nil
case models.ProviderVertexAI:
return &baseProvider[VertexAIClient]{
options: clientOptions,
client: newVertexAIClient(clientOptions),
}, nil
case models.ProviderOpenRouter:
clientOptions.baseURL = "https://openrouter.ai/api/v1"
clientOptions.extraHeaders = map[string]string{
"HTTP-Referer": "crush.charm.land",
"X-Title": "Crush",
}
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderXAI:
clientOptions.baseURL = "https://api.x.ai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderLocal:
clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case models.ProviderMock:
// TODO: implement mock client for test
panic("not implemented")
}
return nil, fmt.Errorf("provider not supported: %s", providerName)
}
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
for _, msg := range messages {
// The message has no content
@@ -167,7 +97,7 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M
return p.client.send(ctx, messages, tools)
}
func (p *baseProvider[C]) Model() models.Model {
func (p *baseProvider[C]) Model() configv2.Model {
return p.options.model
}
@@ -176,7 +106,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
func WithModel(model models.Model) ProviderClientOption {
func WithModel(model configv2.Model) ProviderClientOption {
return func(options *providerClientOptions) {
options.model = model
}
@@ -199,3 +129,53 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
options.systemMessage = systemMessage
}
}
func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
clientOptions := providerClientOptions{
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
extraHeaders: cfg.ExtraHeaders,
}
for _, o := range opts {
o(&clientOptions)
}
switch cfg.ProviderType {
case provider.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
client: newAnthropicClient(clientOptions, false),
}, nil
case provider.TypeOpenAI:
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case provider.TypeGemini:
return &baseProvider[GeminiClient]{
options: clientOptions,
client: newGeminiClient(clientOptions),
}, nil
case provider.TypeBedrock:
return &baseProvider[BedrockClient]{
options: clientOptions,
client: newBedrockClient(clientOptions),
}, nil
case provider.TypeAzure:
return &baseProvider[AzureClient]{
options: clientOptions,
client: newAzureClient(clientOptions),
}, nil
case provider.TypeVertexAI:
return &baseProvider[VertexAIClient]{
options: clientOptions,
client: newVertexAIClient(clientOptions),
}, nil
case provider.TypeXAI:
clientOptions.baseURL = "https://api.x.ai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
}
return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
}

View File

@@ -2,7 +2,6 @@ package provider
import (
"context"
"os"
"github.com/charmbracelet/crush/internal/logging"
"google.golang.org/genai"
@@ -11,9 +10,11 @@ import (
type VertexAIClient ProviderClient
func newVertexAIClient(opts providerClientOptions) VertexAIClient {
project := opts.extraHeaders["project"]
location := opts.extraHeaders["location"]
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
Project: os.Getenv("GOOGLE_CLOUD_PROJECT"),
Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
Project: project,
Location: location,
Backend: genai.BackendVertexAI,
})
if err != nil {

View File

@@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) {
// WaitForServerReady waits for the server to be ready by polling the server
// with a simple request until it responds successfully or times out
func (c *Client) WaitForServerReady(ctx context.Context) error {
cnf := config.Get()
cfg := config.Get()
// Set initial state
c.SetServerState(StateStarting)
@@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Waiting for LSP server to be ready...")
}
@@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// For TypeScript-like servers, we need to open some key files first
if serverType == ServerTypeTypeScript {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("TypeScript-like server detected, opening key configuration files")
}
c.openKeyConfigFiles(ctx)
@@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
if err == nil {
// Server responded successfully
c.SetServerState(StateReady)
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("LSP server is ready")
}
return nil
@@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
}
}
@@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
// openTypeScriptFiles finds and opens TypeScript files to help initialize the server
func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
cnf := config.Get()
cfg := config.Get()
filesOpened := 0
maxFilesToOpen := 5 // Limit to a reasonable number of files
@@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
// Try to open the file
if err := c.OpenFile(ctx, path); err == nil {
filesOpened++
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript file for initialization", "file", path)
}
}
@@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
return nil
})
if err != nil && cnf.DebugLSP {
if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error walking directory for TypeScript files", "error", err)
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
}
}
@@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
}
func (c *Client) CloseFile(ctx context.Context, filepath string) error {
cnf := config.Get()
cfg := config.Get()
uri := string(protocol.URIFromPath(filepath))
c.openFilesMu.Lock()
@@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
},
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Closing file", "file", filepath)
}
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool {
// CloseAllFiles closes all currently open files
func (c *Client) CloseAllFiles(ctx context.Context) {
cnf := config.Get()
cfg := config.Get()
c.openFilesMu.Lock()
filesToClose := make([]string, 0, len(c.openFiles))
@@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
// Then close them all
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
if err != nil && cnf.DebugLSP {
if err != nil && cfg.Options.DebugLSP {
logging.Warn("Error closing file", "file", filePath, "error", err)
}
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Closed all files", "files", filesToClose)
}
}

View File

@@ -82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
// Notifications
func HandleServerMessage(params json.RawMessage) {
cnf := config.Get()
cfg := config.Get()
var msg struct {
Type int `json:"type"`
Message string `json:"message"`
}
if err := json.Unmarshal(params, &msg); err == nil {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
}
}

View File

@@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error {
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
cnf := config.Get()
cfg := config.Get()
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
}
@@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
// ReadMessage reads a single LSP message from the given reader
func ReadMessage(r *bufio.Reader) (*Message, error) {
cnf := config.Get()
cfg := config.Get()
// Read headers
var contentLength int
for {
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
line = strings.TrimSpace(line)
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Received header", "line", line)
}
@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
}
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Content-Length", "length", contentLength)
}
@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
return nil, fmt.Errorf("failed to read content: %w", err)
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Received content", "content", string(content))
}
@@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
// handleMessages reads and dispatches messages in a loop
func (c *Client) handleMessages() {
cnf := config.Get()
cfg := config.Get()
for {
msg, err := ReadMessage(c.stdout)
if err != nil {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Error("Error reading message", "error", err)
}
return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {
// Handle server->client request (has both Method and ID)
if msg.Method != "" && msg.ID != 0 {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
}
@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
c.notificationMu.RUnlock()
if ok {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Handling notification", "method", msg.Method)
}
go handler(msg.Params)
} else if cnf.DebugLSP {
} else if cfg.Options.DebugLSP {
logging.Debug("No handler for notification", "method", msg.Method)
}
continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
c.handlersMu.RUnlock()
if ok {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Received response for request", "id", msg.ID)
}
ch <- msg
close(ch)
} else if cnf.DebugLSP {
} else if cfg.Options.DebugLSP {
logging.Debug("No handler for response", "id", msg.ID)
}
}
@@ -188,10 +188,10 @@ func (c *Client) handleMessages() {
// Call makes a request and waits for the response
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
cnf := config.Get()
cfg := config.Get()
id := c.nextID.Add(1)
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Making call", "method", method, "id", id)
}
@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
return fmt.Errorf("failed to send request: %w", err)
}
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Request sent", "method", method, "id", id)
}
// Wait for response
resp := <-ch
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Received response", "id", id)
}
@@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
// Notify sends a notification (a request without an ID that doesn't expect a response)
func (c *Client) Notify(ctx context.Context, method string, params any) error {
cnf := config.Get()
if cnf.DebugLSP {
cfg := config.Get()
if cfg.Options.DebugLSP {
logging.Debug("Sending notification", "method", method)
}

View File

@@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {
// AddRegistrations adds file watchers to track
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
cnf := config.Get()
cfg := config.Get()
logging.Debug("Adding file watcher registrations")
w.registrationMu.Lock()
@@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
w.registrations = append(w.registrations, watchers...)
// Print detailed registration information for debugging
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Adding file watcher registrations",
"id", id,
"watchers", len(watchers),
@@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName)
filesOpened += highPriorityFilesOpened
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority files",
"count", highPriorityFilesOpened,
"serverName", serverName)
@@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// If we've already opened enough high-priority files, we might not need more
if filesOpened >= maxFilesToOpen {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Reached file limit with high-priority files",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen)
@@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// Skip directories that should be excluded
if d.IsDir() {
if path != w.workspacePath && shouldExcludeDir(path) {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
})
elapsedTime := time.Since(startTime)
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Limited workspace scan complete",
"filesOpened", filesOpened,
"maxFiles", maxFilesToOpen,
@@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
)
}
if err != nil && cnf.DebugLSP {
if err != nil && cfg.Options.DebugLSP {
logging.Debug("Error scanning workspace for files to open", "error", err)
}
}()
} else if cnf.DebugLSP {
} else if cfg.Options.DebugLSP {
logging.Debug("Using on-demand file loading for server", "server", serverName)
}
}
@@ -197,7 +197,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// openHighPriorityFiles opens important files for the server type
// Returns the number of files opened
func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int {
cnf := config.Get()
cfg := config.Get()
filesOpened := 0
// Define patterns for high-priority files based on server type
@@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// Use doublestar.Glob to find files matching the pattern (supports ** patterns)
matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
if err != nil {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
}
continue
@@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
for j := i; j < end; j++ {
fullPath := filesToOpen[j]
if err := w.client.OpenFile(ctx, fullPath); err != nil {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
}
} else {
filesOpened++
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Opened high-priority file", "path", fullPath)
}
}
@@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
// WatchWorkspace sets up file watching for a workspace
func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) {
cnf := config.Get()
cfg := config.Get()
w.workspacePath = workspacePath
// Store the watcher in the context for later use
@@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
// Skip excluded directories (except workspace root)
if d.IsDir() && path != workspacePath {
if shouldExcludeDir(path) {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Skipping excluded directory", "path", path)
}
return filepath.SkipDir
@@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
}
// Debug logging
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
matched, kind := w.isPathWatched(event.Name)
logging.Debug("File event",
"path", event.Name,
@@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
// notifyFileEvent sends a didChangeWatchedFiles notification for a file event
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
cnf := config.Get()
if cnf.DebugLSP {
cfg := config.Get()
if cfg.Options.DebugLSP {
logging.Debug("Notifying file event",
"uri", uri,
"changeType", changeType,
@@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool {
// shouldExcludeFile returns true if the file should be excluded from opening
func shouldExcludeFile(filePath string) bool {
fileName := filepath.Base(filePath)
cnf := config.Get()
cfg := config.Get()
// Skip dot files
if strings.HasPrefix(fileName, ".") {
return true
@@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool {
// Skip large files
if info.Size() > maxFileSize {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Skipping large file",
"path", filePath,
"size", info.Size(),
"maxSize", maxFileSize,
"debug", cnf.Debug,
"debug", cfg.Options.Debug,
"sizeMB", float64(info.Size())/(1024*1024),
"maxSizeMB", float64(maxFileSize)/(1024*1024),
)
@@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool {
// openMatchingFile opens a file if it matches any of the registered patterns
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
cnf := config.Get()
cfg := config.Get()
// Skip directories
info, err := os.Stat(path)
if err != nil || info.IsDir() {
@@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check if the file is a high-priority file that should be opened immediately
// This helps with project initialization for certain language servers
if isHighPriorityFile(path, serverName) {
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
}
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening high-priority file", "path", path, "error", err)
}
return
@@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
// Check file size - for preloading we're more conservative
if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
if cnf.DebugLSP {
if cfg.Options.DebugLSP {
logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
}
return
@@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
if shouldOpen {
// Don't need to check if it's already open - the client.OpenFile handles that
if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
logging.Error("Error opening file", "path", path, "error", err)
}
}

View File

@@ -5,7 +5,7 @@ import (
"slices"
"time"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/fur/provider"
)
type MessageRole string
@@ -71,9 +71,9 @@ type BinaryContent struct {
Data []byte
}
func (bc BinaryContent) String(provider models.InferenceProvider) string {
func (bc BinaryContent) String(p provider.InferenceProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
if provider == models.ProviderOpenAI {
if p == provider.InferenceProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
}
return base64Encoded
@@ -113,7 +113,8 @@ type Message struct {
Role MessageRole
SessionID string
Parts []ContentPart
Model models.ModelID
Model string
Provider string
CreatedAt int64
UpdatedAt int64
}

View File

@@ -8,15 +8,15 @@ import (
"time"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
type CreateMessageParams struct {
Role MessageRole
Parts []ContentPart
Model models.ModelID
Role MessageRole
Parts []ContentPart
Model string
Provider string
}
type Service interface {
@@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
Role: string(params.Role),
Parts: string(partsJSON),
Model: sql.NullString{String: string(params.Model), Valid: true},
Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
})
if err != nil {
return Message{}, err
@@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
SessionID: item.SessionID,
Role: MessageRole(item.Role),
Parts: parts,
Model: models.ModelID(item.Model.String),
Model: item.Model.String,
Provider: item.Provider.String,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}, nil

View File

@@ -7,7 +7,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -112,11 +111,7 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
cfg := config.Get()
agentCfg := cfg.Agents[config.AgentCoder]
selectedModelID := agentCfg.Model
model := models.SupportedModels[selectedModelID]
model := config.GetAgentModel(config.AgentCoder)
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)

View File

@@ -10,7 +10,8 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name)
assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg)
model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
modelFormatted := t.S().Muted.Render(model.Name)
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return tea.NewView(
t.S().Base.PaddingLeft(2).Render(
core.Section(assistant, m.width-2),

View File

@@ -13,7 +13,6 @@ import (
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string {
mcpList := []string{section, ""}
mcp := config.Get().MCPServers
mcp := config.Get().MCP
if len(mcp) == 0 {
return lipgloss.JoinVertical(
lipgloss.Left,
@@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
cfg := config.Get()
agentCfg := cfg.Agents[config.AgentCoder]
selectedModelID := agentCfg.Model
model := models.SupportedModels[selectedModelID]
model := config.GetAgentModel(config.AgentCoder)
t := styles.CurrentTheme()

View File

@@ -63,7 +63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {
// Project directory
sources = append(sources, commandSource{
path: filepath.Join(cfg.Data.Directory, "commands"),
path: filepath.Join(cfg.Options.DataDirectory, "commands"),
prefix: ProjectCommandPrefix,
})

View File

@@ -5,7 +5,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/config"
configv2 "github.com/charmbracelet/crush/internal/config"
cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
@@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (
Add the .crush directory to the .gitignore file if it's not already there.`
// Mark the project as initialized
if err := config.MarkProjectInitialized(); err != nil {
if err := configv2.MarkProjectInitialized(); err != nil {
return util.ReportError(err)
}
@@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.`
)
} else {
// Mark the project as initialized without running the command
if err := config.MarkProjectInitialized(); err != nil {
if err := configv2.MarkProjectInitialized(); err != nil {
return util.ReportError(err)
}
}

View File

@@ -1,13 +1,11 @@
package models
import (
"slices"
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
@@ -26,7 +24,7 @@ const (
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
Model models.Model
Model configv2.PreferredModel
}
// CloseModelDialogMsg is sent when a model is selected
@@ -37,6 +35,11 @@ type ModelDialog interface {
dialogs.DialogModel
}
type ModelOption struct {
Provider provider.Provider
Model provider.Model
}
type modelDialogCmp struct {
width int
wWidth int // Width of the terminal window
@@ -80,47 +83,31 @@ func NewModelDialogCmp() ModelDialog {
}
}
var ProviderPopularity = map[models.InferenceProvider]int{
models.ProviderAnthropic: 1,
models.ProviderOpenAI: 2,
models.ProviderGemini: 3,
models.ProviderGROQ: 4,
models.ProviderOpenRouter: 5,
models.ProviderBedrock: 6,
models.ProviderAzure: 7,
models.ProviderVertexAI: 8,
models.ProviderXAI: 9,
}
var ProviderName = map[models.InferenceProvider]string{
models.ProviderAnthropic: "Anthropic",
models.ProviderOpenAI: "OpenAI",
models.ProviderGemini: "Gemini",
models.ProviderGROQ: "Groq",
models.ProviderOpenRouter: "OpenRouter",
models.ProviderBedrock: "AWS Bedrock",
models.ProviderAzure: "Azure",
models.ProviderVertexAI: "VertexAI",
models.ProviderXAI: "xAI",
}
func (m *modelDialogCmp) Init() tea.Cmd {
cfg := config.Get()
enabledProviders := getEnabledProviders(cfg)
providers := configv2.Providers()
cfg := configv2.Get()
coderAgent := cfg.Agents[configv2.AgentCoder]
modelItems := []util.Model{}
for _, provider := range enabledProviders {
name, ok := ProviderName[provider]
if !ok {
name = string(provider) // Fallback to provider ID if name is not defined
selectIndex := 0
for _, provider := range providers {
name := provider.Name
if name == "" {
name = string(provider.ID)
}
modelItems = append(modelItems, commands.NewItemSection(name))
for _, model := range getModelsForProvider(provider) {
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
for _, model := range provider.Models {
if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider {
selectIndex = len(modelItems) // Set the selected index to the current model
}
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
Provider: provider,
Model: model,
}))
}
}
m.modelList.SetItems(modelItems)
return m.modelList.Init()
return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
}
func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -137,11 +124,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil // No item selected, do nothing
}
items := m.modelList.Items()
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{
ModelID: selectedItem.Model.ID,
Provider: selectedItem.Provider.ID,
}}),
)
case key.Matches(msg, m.keyMap.Close):
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
@@ -189,58 +179,6 @@ func (m *modelDialogCmp) listHeight() int {
return min(listHeigh, m.wHeight/2)
}
func GetSelectedModel(cfg *config.Config) models.Model {
agentCfg := cfg.Agents[config.AgentCoder]
selectedModelID := agentCfg.Model
return models.SupportedModels[selectedModelID]
}
func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
var providers []models.InferenceProvider
for providerID, provider := range cfg.Providers {
if !provider.Disabled {
providers = append(providers, providerID)
}
}
// Sort by provider popularity
slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
rA := ProviderPopularity[a]
rB := ProviderPopularity[b]
// models not included in popularity ranking default to last
if rA == 0 {
rA = 999
}
if rB == 0 {
rB = 999
}
return rA - rB
})
return providers
}
func getModelsForProvider(provider models.InferenceProvider) []models.Model {
var providerModels []models.Model
for _, model := range models.SupportedModels {
if model.Provider == provider {
providerModels = append(providerModels, model)
}
}
// reverse alphabetical order (if llm naming was consistent latest would appear first)
slices.SortFunc(providerModels, func(a, b models.Model) int {
if a.Name > b.Name {
return -1
} else if a.Name < b.Name {
return 1
}
return 0
})
return providerModels
}
func (m *modelDialogCmp) Position() (int, int) {
row := m.wHeight/4 - 2 // just a bit above the center
col := m.wWidth / 2

View File

@@ -9,7 +9,6 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/models"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/tui/components/chat"
@@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
util.CmdHandler(ChatFocusedMsg{Focused: false}),
)
case key.Matches(msg, p.keyMap.AddAttachment):
cfg := config.Get()
agentCfg := cfg.Agents[config.AgentCoder]
selectedModelID := agentCfg.Model
model := models.SupportedModels[selectedModelID]
if model.SupportsAttachments {
model := config.GetAgentModel(config.AgentCoder)
if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID))
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
}
case key.Matches(msg, p.keyMap.Tab):
if p.session.ID == "" {

View File

@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
configv2 "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/agent"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/permission"
@@ -69,7 +70,7 @@ func (a appModel) Init() tea.Cmd {
// Check if we should show the init dialog
cmds = append(cmds, func() tea.Msg {
shouldShow, err := config.ShouldShowInitDialog()
shouldShow, err := configv2.ProjectNeedsInitialization()
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
@@ -172,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Model Switch
case models.ModelSelectedMsg:
model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID)
model, err := a.app.CoderAgent.Update(msg.Model)
if err != nil {
return a, util.ReportError(err)
}
@@ -222,7 +223,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
model := a.app.CoderAgent.Model()
contextWindow := model.ContextWindow
tokens := session.CompletionTokens + session.PromptTokens
if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact {
if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize {
// Show compact confirmation dialog
cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),