mirror of
https://github.com/charmbracelet/crush.git
chore: move to the new config
@@ -72,7 +72,8 @@ to assist developers in writing, debugging, and understanding code directly from
		}
		cwd = c
	}
	_, err := config.Load(cwd, debug)

	_, err := config.Init(cwd, debug)
	if err != nil {
		return err
	}
@@ -1,3 +1,4 @@
// TODO: FIX THIS
package main

import (

@@ -6,7 +7,6 @@ import (
	"os"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
)

// JSONSchemaType represents a JSON Schema type

@@ -192,22 +192,10 @@ func generateSchema() map[string]any {
		},
	}

	// Add known providers
	knownProviders := []string{
		string(models.ProviderAnthropic),
		string(models.ProviderOpenAI),
		string(models.ProviderGemini),
		string(models.ProviderGROQ),
		string(models.ProviderOpenRouter),
		string(models.ProviderBedrock),
		string(models.ProviderAzure),
		string(models.ProviderVertexAI),
	}

	providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{
		"type":        "string",
		"description": "Provider type",
		"enum":        knownProviders,
		"enum":        []string{},
	}

	schema["properties"].(map[string]any)["providers"] = providerSchema

@@ -241,9 +229,7 @@ func generateSchema() map[string]any {

	// Add model enum
	modelEnum := []string{}
	for modelID := range models.SupportedModels {
		modelEnum = append(modelEnum, string(modelID))
	}

	agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum

	// Add specific agent properties

@@ -251,7 +237,6 @@ func generateSchema() map[string]any {
	knownAgents := []string{
		string(config.AgentCoder),
		string(config.AgentTask),
		string(config.AgentTitle),
	}

	for _, agentName := range knownAgents {
@@ -9,7 +9,7 @@ import (
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	configv2 "github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/history"

@@ -55,18 +55,21 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
	// Initialize LSP clients in the background
	go app.initLSPClients(ctx)

	cfg := configv2.Get()

	coderAgentCfg := cfg.Agents[configv2.AgentCoder]
	if coderAgentCfg.ID == "" {
		return nil, fmt.Errorf("coder agent configuration is missing")
	}

	var err error
	app.CoderAgent, err = agent.NewAgent(
		config.AgentCoder,
		coderAgentCfg,
		app.Permissions,
		app.Sessions,
		app.Messages,
		agent.CoderAgentTools(
			app.Permissions,
			app.Sessions,
			app.Messages,
			app.History,
			app.LSPClients,
		),
		app.History,
		app.LSPClients,
	)
	if err != nil {
		logging.Error("Failed to create coder agent", err)
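For orientation, the call order the new package expects — load once at startup, then read the cached instance anywhere — can be sketched like this. A minimal illustration, assuming only the config.Init, config.Get, and Agents map shown in this diff:

	// Sketch: startup wiring under the new config package.
	if _, err := config.Init(cwd, debug); err != nil {
		return err
	}
	cfg := config.Get()
	coderCfg := cfg.Agents[config.AgentCoder]
	if coderCfg.ID == "" { // zero-value ID means the agent was never configured
		return fmt.Errorf("coder agent configuration is missing")
	}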
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
package configv2
package config

import (
	"encoding/json"

@@ -28,7 +28,7 @@ func TestConfigWithEnv(t *testing.T) {
	os.Setenv("GEMINI_API_KEY", "test-gemini-key")
	os.Setenv("XAI_API_KEY", "test-xai-key")
	os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
	cfg := InitConfig(cwdDir)
	cfg, _ := Init(cwdDir, false)
	data, _ := json.MarshalIndent(cfg, "", "  ")
	fmt.Println(string(data))
	assert.Len(t, cfg.Providers, 5)
@@ -1,4 +1,4 @@
package configv2
package config

import (
	"fmt"

@@ -17,23 +17,20 @@ type ProjectInitFlag struct {
	Initialized bool `json:"initialized"`
}

// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory
func ShouldShowInitDialog() (bool, error) {
	if cfg == nil {
// ProjectNeedsInitialization checks if the current project needs initialization
func ProjectNeedsInitialization() (bool, error) {
	if instance == nil {
		return false, fmt.Errorf("config not loaded")
	}

	// Create the flag file path
	flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
	flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)

	// Check if the flag file exists
	_, err := os.Stat(flagFilePath)
	if err == nil {
		// File exists, don't show the dialog
		return false, nil
	}

	// If the error is not "file not found", return the error
	if !os.IsNotExist(err) {
		return false, fmt.Errorf("failed to check init flag file: %w", err)
	}

@@ -44,11 +41,9 @@ func ShouldShowInitDialog() (bool, error) {
		return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
	}
	if crushExists {
		// CRUSH.md already exists, don't show the dialog
		return false, nil
	}

	// File doesn't exist, show the dialog
	return true, nil
}

@@ -75,13 +70,11 @@ func crushMdExists(dir string) (bool, error) {

// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
	if cfg == nil {
	if instance == nil {
		return fmt.Errorf("config not loaded")
	}
	// Create the flag file path
	flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename)
	flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)

	// Create an empty file to mark the project as initialized
	file, err := os.Create(flagFilePath)
	if err != nil {
		return fmt.Errorf("failed to create init flag file: %w", err)
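Taken together, the rename makes the call site read as a plain predicate. A minimal usage sketch, assuming the two functions keep the signatures shown above:

	// Sketch: first-run detection with the renamed API.
	needsInit, err := config.ProjectNeedsInitialization()
	if err != nil {
		return err
	}
	if needsInit {
		// ... generate CRUSH.md, then record the flag file so we don't ask again.
		if err := config.MarkProjectInitialized(); err != nil {
			return err
		}
	}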
@@ -1,4 +1,4 @@
package configv2
package config

import (
	"encoding/json"
@@ -1,660 +0,0 @@
package configv2

import (
	"encoding/json"
	"errors"
	"maps"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/logging"
)

const (
	defaultDataDirectory = ".crush"
	defaultLogLevel      = "info"
	appName              = "crush"

	MaxTokensFallbackDefault = 4096
)

var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}

type AgentID string

const (
	AgentCoder     AgentID = "coder"
	AgentTask      AgentID = "task"
	AgentTitle     AgentID = "title"
	AgentSummarize AgentID = "summarize"
)

type Model struct {
	ID                 string  `json:"id"`
	Name               string  `json:"model"`
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
	ContextWindow      int64   `json:"context_window"`
	DefaultMaxTokens   int64   `json:"default_max_tokens"`
	CanReason          bool    `json:"can_reason"`
	ReasoningEffort    string  `json:"reasoning_effort"`
	SupportsImages     bool    `json:"supports_attachments"`
}

type VertexAIOptions struct {
	APIKey   string `json:"api_key,omitempty"`
	Project  string `json:"project,omitempty"`
	Location string `json:"location,omitempty"`
}

type ProviderConfig struct {
	ID           provider.InferenceProvider `json:"id"`
	BaseURL      string                     `json:"base_url,omitempty"`
	ProviderType provider.Type              `json:"provider_type"`
	APIKey       string                     `json:"api_key,omitempty"`
	Disabled     bool                       `json:"disabled"`
	ExtraHeaders map[string]string          `json:"extra_headers,omitempty"`
	// used e.g. by Vertex AI to set the project
	ExtraParams map[string]string `json:"extra_params,omitempty"`

	DefaultLargeModel string `json:"default_large_model,omitempty"`
	DefaultSmallModel string `json:"default_small_model,omitempty"`

	Models []Model `json:"models,omitempty"`
}

type Agent struct {
	ID          AgentID `json:"id"`
	Name        string  `json:"name"`
	Description string  `json:"description,omitempty"`
	// This is the id of the system prompt used by the agent
	Disabled bool `json:"disabled"`

	Provider provider.InferenceProvider `json:"provider"`
	Model    string                     `json:"model"`

	// The available tools for the agent
	// if this is nil, all tools are available
	AllowedTools []string `json:"allowed_tools"`

	// this tells us which MCPs are available for this agent
	// if this is empty all mcps are available
	// the string array is the list of tools from the AllowedMCP the agent has available
	// if the string array is nil, all tools from the AllowedMCP are available
	AllowedMCP map[string][]string `json:"allowed_mcp"`

	// The list of LSPs that this agent can use
	// if this is nil, all LSPs are available
	AllowedLSP []string `json:"allowed_lsp"`

	// Overrides the context paths for this agent
	ContextPaths []string `json:"context_paths"`
}

type MCPType string

const (
	MCPStdio MCPType = "stdio"
	MCPSse   MCPType = "sse"
)

type MCP struct {
	Command string            `json:"command"`
	Env     []string          `json:"env"`
	Args    []string          `json:"args"`
	Type    MCPType           `json:"type"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
}

type LSPConfig struct {
	Disabled bool     `json:"disabled"`
	Command  string   `json:"command"`
	Args     []string `json:"args"`
	Options  any      `json:"options"`
}

type TUIOptions struct {
	CompactMode bool `json:"compact_mode"`
	// Here we can add themes later or any TUI related options
}

type Options struct {
	ContextPaths         []string   `json:"context_paths"`
	TUI                  TUIOptions `json:"tui"`
	Debug                bool       `json:"debug"`
	DebugLSP             bool       `json:"debug_lsp"`
	DisableAutoSummarize bool       `json:"disable_auto_summarize"`
	// Relative to the cwd
	DataDirectory string `json:"data_directory"`
}

type Config struct {
	// List of configured providers
	Providers map[provider.InferenceProvider]ProviderConfig `json:"providers,omitempty"`

	// List of configured agents
	Agents map[AgentID]Agent `json:"agents,omitempty"`

	// List of configured MCPs
	MCP map[string]MCP `json:"mcp,omitempty"`

	// List of configured LSPs
	LSP map[string]LSPConfig `json:"lsp,omitempty"`

	// Miscellaneous options
	Options Options `json:"options"`
}

var (
	instance *Config   // The single instance of the Singleton
	cwd      string
	once     sync.Once // Ensures the initialization happens only once
)
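The struct tags above imply a JSON shape for crush.json. A hypothetical minimal config round-tripped through the Config type — the provider name, key, and model ID here are illustrative, not defaults shipped by this commit:

	// Sketch: a hypothetical crush.json matching the json tags on Config.
	raw := []byte(`{
	  "providers": {
	    "anthropic": {
	      "id": "anthropic",
	      "provider_type": "anthropic",
	      "api_key": "sk-example",
	      "default_large_model": "claude-4-sonnet"
	    }
	  },
	  "options": {
	    "data_directory": ".crush",
	    "debug": false
	  }
	}`)
	var cfg Config
	if err := json.Unmarshal(raw, &cfg); err != nil {
		log.Fatal(err) // illustrative error handling
	}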
func loadConfig(cwd string) (*Config, error) {
	// First read the global config file
	cfgPath := ConfigPath()

	cfg := defaultConfigBasedOnEnv()

	var globalCfg *Config
	if _, err := os.Stat(cfgPath); err != nil && !os.IsNotExist(err) {
		// some other error occurred while checking the file
		return nil, err
	} else if err == nil {
		// config file exists, read it
		file, err := os.ReadFile(cfgPath)
		if err != nil {
			return nil, err
		}
		globalCfg = &Config{}
		if err := json.Unmarshal(file, globalCfg); err != nil {
			return nil, err
		}
	} else {
		// config file does not exist, create a new one
		globalCfg = &Config{}
	}

	var localConfig *Config
	// Global config loaded, now read the local config file
	localConfigPath := filepath.Join(cwd, "crush.json")
	if _, err := os.Stat(localConfigPath); err != nil && !os.IsNotExist(err) {
		// some other error occurred while checking the file
		return nil, err
	} else if err == nil {
		// local config file exists, read it
		file, err := os.ReadFile(localConfigPath)
		if err != nil {
			return nil, err
		}
		localConfig = &Config{}
		if err := json.Unmarshal(file, localConfig); err != nil {
			return nil, err
		}
	}

	// merge options
	mergeOptions(cfg, globalCfg, localConfig)

	mergeProviderConfigs(cfg, globalCfg, localConfig)
	// no providers found; the app is not initialized yet
	if len(cfg.Providers) == 0 {
		return cfg, nil
	}
	preferredProvider := getPreferredProvider(cfg.Providers)

	if preferredProvider == nil {
		return nil, errors.New("no valid providers configured")
	}

	agents := map[AgentID]Agent{
		AgentCoder: {
			ID:           AgentCoder,
			Name:         "Coder",
			Description:  "An agent that helps with executing coding tasks.",
			Provider:     preferredProvider.ID,
			Model:        preferredProvider.DefaultLargeModel,
			ContextPaths: cfg.Options.ContextPaths,
			// All tools allowed
		},
		AgentTask: {
			ID:           AgentTask,
			Name:         "Task",
			Description:  "An agent that helps with searching for context and finding implementation details.",
			Provider:     preferredProvider.ID,
			Model:        preferredProvider.DefaultLargeModel,
			ContextPaths: cfg.Options.ContextPaths,
			AllowedTools: []string{
				"glob",
				"grep",
				"ls",
				"sourcegraph",
				"view",
			},
			// NO MCPs or LSPs by default
			AllowedMCP: map[string][]string{},
			AllowedLSP: []string{},
		},
		AgentTitle: {
			ID:           AgentTitle,
			Name:         "Title",
			Description:  "An agent that helps with generating titles for sessions.",
			Provider:     preferredProvider.ID,
			Model:        preferredProvider.DefaultSmallModel,
			ContextPaths: cfg.Options.ContextPaths,
			AllowedTools: []string{},
			// NO MCPs or LSPs by default
			AllowedMCP: map[string][]string{},
			AllowedLSP: []string{},
		},
		AgentSummarize: {
			ID:           AgentSummarize,
			Name:         "Summarize",
			Description:  "An agent that helps with summarizing sessions.",
			Provider:     preferredProvider.ID,
			Model:        preferredProvider.DefaultSmallModel,
			ContextPaths: cfg.Options.ContextPaths,
			AllowedTools: []string{},
			// NO MCPs or LSPs by default
			AllowedMCP: map[string][]string{},
			AllowedLSP: []string{},
		},
	}
	cfg.Agents = agents
	mergeAgents(cfg, globalCfg, localConfig)
	mergeMCPs(cfg, globalCfg, localConfig)
	mergeLSPs(cfg, globalCfg, localConfig)

	return cfg, nil
}

func InitConfig(workingDir string) *Config {
	once.Do(func() {
		cwd = workingDir
		cfg, err := loadConfig(cwd)
		if err != nil {
			// TODO: Handle this better
			panic("Failed to load config: " + err.Error())
		}
		instance = cfg
	})

	return instance
}

func GetConfig() *Config {
	if instance == nil {
		// TODO: Handle this better
		panic("Config not initialized. Call InitConfig first.")
	}
	return instance
}
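Note the trade-off this file baked in: sync.Once cannot re-run a failed initialization, so loadConfig errors could only surface as a panic. The replacement API (config.Init, seen earlier in this diff) hands the error back instead. A sketch of the difference, assuming Init's (config, error) signature:

	// Old shape: the error is swallowed into a panic inside once.Do.
	cfg := config.InitConfig(cwd) // panics if loading fails

	// New shape: the caller decides what to do with a load failure.
	cfg2, err := config.Init(cwd, false)
	if err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}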
func getPreferredProvider(configuredProviders map[provider.InferenceProvider]ProviderConfig) *ProviderConfig {
	providers := Providers()
	for _, p := range providers {
		if providerConfig, ok := configuredProviders[p.ID]; ok && !providerConfig.Disabled {
			return &providerConfig
		}
	}
	// if none found return the first configured provider
	for _, providerConfig := range configuredProviders {
		if !providerConfig.Disabled {
			return &providerConfig
		}
	}
	return nil
}

func mergeProviderConfig(p provider.InferenceProvider, base, other ProviderConfig) ProviderConfig {
	if other.APIKey != "" {
		base.APIKey = other.APIKey
	}
	// Only change these options if the provider is not a known provider
	if !slices.Contains(provider.KnownProviders(), p) {
		if other.BaseURL != "" {
			base.BaseURL = other.BaseURL
		}
		if other.ProviderType != "" {
			base.ProviderType = other.ProviderType
		}
		if len(other.ExtraHeaders) > 0 {
			if base.ExtraHeaders == nil {
				base.ExtraHeaders = make(map[string]string)
			}
			maps.Copy(base.ExtraHeaders, other.ExtraHeaders)
		}
		if len(other.ExtraParams) > 0 {
			if base.ExtraParams == nil {
				base.ExtraParams = make(map[string]string)
			}
			maps.Copy(base.ExtraParams, other.ExtraParams)
		}
	}

	if other.Disabled {
		base.Disabled = other.Disabled
	}

	if other.DefaultLargeModel != "" {
		base.DefaultLargeModel = other.DefaultLargeModel
	}
	// Add new models if they don't exist
	if other.Models != nil {
		for _, model := range other.Models {
			// check if the model already exists
			exists := false
			for _, existingModel := range base.Models {
				if existingModel.ID == model.ID {
					exists = true
					break
				}
			}
			if !exists {
				base.Models = append(base.Models, model)
			}
		}
	}

	return base
}

func validateProvider(p provider.InferenceProvider, providerConfig ProviderConfig) error {
	if !slices.Contains(provider.KnownProviders(), p) {
		if providerConfig.ProviderType != provider.TypeOpenAI {
			return errors.New("invalid provider type: " + string(providerConfig.ProviderType))
		}
		if providerConfig.BaseURL == "" {
			return errors.New("base URL must be set for custom providers")
		}
		if providerConfig.APIKey == "" {
			return errors.New("API key must be set for custom providers")
		}
	}
	return nil
}
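Concretely, validateProvider only constrains unknown (custom) providers: they must speak the OpenAI-compatible API and carry both a base URL and a key. A sketch with a hypothetical custom provider ID:

	// Sketch: a custom, OpenAI-compatible endpoint that passes validateProvider.
	custom := ProviderConfig{
		ID:           provider.InferenceProvider("my-local-llm"), // hypothetical ID
		ProviderType: provider.TypeOpenAI,                        // required for custom providers
		BaseURL:      "http://localhost:11434/v1",                // required; illustrative URL
		APIKey:       "unused-but-required",                      // required
	}
	if err := validateProvider(custom.ID, custom); err != nil {
		log.Fatal(err)
	}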
func mergeOptions(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		baseOptions := base.Options
		other := cfg.Options
		if len(other.ContextPaths) > 0 {
			baseOptions.ContextPaths = append(baseOptions.ContextPaths, other.ContextPaths...)
		}

		if other.TUI.CompactMode {
			baseOptions.TUI.CompactMode = other.TUI.CompactMode
		}

		if other.Debug {
			baseOptions.Debug = other.Debug
		}

		if other.DebugLSP {
			baseOptions.DebugLSP = other.DebugLSP
		}

		if other.DisableAutoSummarize {
			baseOptions.DisableAutoSummarize = other.DisableAutoSummarize
		}

		if other.DataDirectory != "" {
			baseOptions.DataDirectory = other.DataDirectory
		}
		base.Options = baseOptions
	}
}

func mergeAgents(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		for agentID, newAgent := range cfg.Agents {
			if _, ok := base.Agents[agentID]; !ok {
				newAgent.ID = agentID // Ensure the ID is set correctly
				base.Agents[agentID] = newAgent
			} else {
				switch agentID {
				case AgentCoder:
					baseAgent := base.Agents[agentID]
					baseAgent.Model = newAgent.Model
					baseAgent.Provider = newAgent.Provider
					baseAgent.AllowedMCP = newAgent.AllowedMCP
					baseAgent.AllowedLSP = newAgent.AllowedLSP
					base.Agents[agentID] = baseAgent
				case AgentTask:
					baseAgent := base.Agents[agentID]
					baseAgent.Model = newAgent.Model
					baseAgent.Provider = newAgent.Provider
					base.Agents[agentID] = baseAgent
				case AgentTitle:
					baseAgent := base.Agents[agentID]
					baseAgent.Model = newAgent.Model
					baseAgent.Provider = newAgent.Provider
					base.Agents[agentID] = baseAgent
				case AgentSummarize:
					baseAgent := base.Agents[agentID]
					baseAgent.Model = newAgent.Model
					baseAgent.Provider = newAgent.Provider
					base.Agents[agentID] = baseAgent
				default:
					baseAgent := base.Agents[agentID]
					baseAgent.Name = newAgent.Name
					baseAgent.Description = newAgent.Description
					baseAgent.Disabled = newAgent.Disabled
					baseAgent.Provider = newAgent.Provider
					baseAgent.Model = newAgent.Model
					baseAgent.AllowedTools = newAgent.AllowedTools
					baseAgent.AllowedMCP = newAgent.AllowedMCP
					baseAgent.AllowedLSP = newAgent.AllowedLSP
					base.Agents[agentID] = baseAgent
				}
			}
		}
	}
}

func mergeMCPs(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		maps.Copy(base.MCP, cfg.MCP)
	}
}

func mergeLSPs(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		maps.Copy(base.LSP, cfg.LSP)
	}
}

func mergeProviderConfigs(base, global, local *Config) {
	for _, cfg := range []*Config{global, local} {
		if cfg == nil {
			continue
		}
		for providerName, globalProvider := range cfg.Providers {
			if _, ok := base.Providers[providerName]; !ok {
				base.Providers[providerName] = globalProvider
			} else {
				base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
			}
		}
	}

	finalProviders := make(map[provider.InferenceProvider]ProviderConfig)
	for providerName, providerConfig := range base.Providers {
		err := validateProvider(providerName, providerConfig)
		if err != nil {
			logging.Warn("Skipping provider", "name", providerName, "error", err)
			continue // actually skip the invalid provider, as the log line says
		}
		finalProviders[providerName] = providerConfig
	}
	base.Providers = finalProviders
}
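Because each merge helper walks []*Config{global, local} in order, local values overwrite global ones, which in turn overwrite the env-derived defaults. A tiny sketch of that precedence for a scalar option:

	// Sketch: the local config wins over the global one.
	base := defaultConfigBasedOnEnv()
	global := &Config{Options: Options{DataDirectory: ".crush-global"}}
	local := &Config{Options: Options{DataDirectory: ".crush-local"}}
	mergeOptions(base, global, local)
	// base.Options.DataDirectory is now ".crush-local"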
func providerDefaultConfig(providerId provider.InferenceProvider) ProviderConfig {
	switch providerId {
	case provider.InferenceProviderAnthropic:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeAnthropic,
		}
	case provider.InferenceProviderOpenAI:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeOpenAI,
		}
	case provider.InferenceProviderGemini:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeGemini,
		}
	case provider.InferenceProviderBedrock:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeBedrock,
		}
	case provider.InferenceProviderAzure:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeAzure,
		}
	case provider.InferenceProviderOpenRouter:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeOpenAI,
			BaseURL:      "https://openrouter.ai/api/v1",
			ExtraHeaders: map[string]string{
				"HTTP-Referer": "crush.charm.land",
				"X-Title":      "Crush",
			},
		}
	case provider.InferenceProviderXAI:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeXAI,
			BaseURL:      "https://api.x.ai/v1",
		}
	case provider.InferenceProviderVertexAI:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeVertexAI,
		}
	default:
		return ProviderConfig{
			ID:           providerId,
			ProviderType: provider.TypeOpenAI,
		}
	}
}
func defaultConfigBasedOnEnv() *Config {
	cfg := &Config{
		Options: Options{
			DataDirectory: defaultDataDirectory,
			ContextPaths:  defaultContextPaths,
		},
		Providers: make(map[provider.InferenceProvider]ProviderConfig),
	}

	providers := Providers()

	for _, p := range providers {
		if strings.HasPrefix(p.APIKey, "$") {
			envVar := strings.TrimPrefix(p.APIKey, "$")
			if apiKey := os.Getenv(envVar); apiKey != "" {
				providerConfig := providerDefaultConfig(p.ID)
				providerConfig.APIKey = apiKey
				providerConfig.DefaultLargeModel = p.DefaultLargeModelID
				providerConfig.DefaultSmallModel = p.DefaultSmallModelID
				for _, model := range p.Models {
					providerConfig.Models = append(providerConfig.Models, Model{
						ID:                 model.ID,
						Name:               model.Name,
						CostPer1MIn:        model.CostPer1MIn,
						CostPer1MOut:       model.CostPer1MOut,
						CostPer1MInCached:  model.CostPer1MInCached,
						CostPer1MOutCached: model.CostPer1MOutCached,
						ContextWindow:      model.ContextWindow,
						DefaultMaxTokens:   model.DefaultMaxTokens,
						CanReason:          model.CanReason,
						SupportsImages:     model.SupportsImages,
					})
				}
				cfg.Providers[p.ID] = providerConfig
			}
		}
	}
	// TODO: support local models

	if useVertexAI := os.Getenv("GOOGLE_GENAI_USE_VERTEXAI"); useVertexAI == "true" {
		providerConfig := providerDefaultConfig(provider.InferenceProviderVertexAI)
		providerConfig.ExtraParams = map[string]string{
			"project":  os.Getenv("GOOGLE_CLOUD_PROJECT"),
			"location": os.Getenv("GOOGLE_CLOUD_LOCATION"),
		}
		cfg.Providers[provider.InferenceProviderVertexAI] = providerConfig
	}

	if hasAWSCredentials() {
		providerConfig := providerDefaultConfig(provider.InferenceProviderBedrock)
		cfg.Providers[provider.InferenceProviderBedrock] = providerConfig
	}
	return cfg
}

func hasAWSCredentials() bool {
	if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
		return true
	}

	if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
		return true
	}

	if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
		return true
	}

	if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
		os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
		return true
	}

	return false
}

func WorkingDirectory() string {
	return cwd
}
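The effect of defaultConfigBasedOnEnv is that exporting a single key is enough to light up a provider. A sketch — the variable names follow the test earlier in this diff, and the exact set of recognized keys depends on the provider catalog served by fur:

	// Sketch: providers appear in the default config once their env keys are set.
	os.Setenv("OPENROUTER_API_KEY", "or-test")
	os.Setenv("GEMINI_API_KEY", "gm-test")
	cfg, err := config.Init(cwd, false)
	if err != nil {
		return err
	}
	for id := range cfg.Providers {
		fmt.Println("configured provider:", id)
	}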
@@ -1,7 +1,6 @@
package db

import (
	"context"
	"database/sql"
	"fmt"
	"os"

@@ -16,8 +15,8 @@ import (
	"github.com/pressly/goose/v3"
)

func Connect(ctx context.Context) (*sql.DB, error) {
	dataDir := config.Get().Data.Directory
func Connect() (*sql.DB, error) {
	dataDir := config.Get().Options.DataDirectory
	if dataDir == "" {
		return nil, fmt.Errorf("data.dir is not set")
	}
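Dropping the ctx parameter makes the ordering constraint explicit: Connect now reads the already-initialized config singleton, so config.Init must run first. A sketch:

	// Sketch: config must be initialized before opening the database.
	if _, err := config.Init(cwd, false); err != nil {
		return err
	}
	conn, err := db.Connect() // resolves Options.DataDirectory internally
	if err != nil {
		return err
	}
	defer conn.Close()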
@@ -17,12 +17,13 @@ INSERT INTO messages (
	role,
	parts,
	model,
	provider,
	created_at,
	updated_at
) VALUES (
	?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
	?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
`

type CreateMessageParams struct {

@@ -31,6 +32,7 @@ type CreateMessageParams struct {
	Role     string         `json:"role"`
	Parts    string         `json:"parts"`
	Model    sql.NullString `json:"model"`
	Provider sql.NullString `json:"provider"`
}

func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {

@@ -40,6 +42,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
		arg.Role,
		arg.Parts,
		arg.Model,
		arg.Provider,
	)
	var i Message
	err := row.Scan(

@@ -51,6 +54,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.FinishedAt,
		&i.Provider,
	)
	return i, err
}

@@ -76,7 +80,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e
}

const getMessage = `-- name: GetMessage :one
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE id = ? LIMIT 1
`

@@ -93,12 +97,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.FinishedAt,
		&i.Provider,
	)
	return i, err
}

const listMessagesBySession = `-- name: ListMessagesBySession :many
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider
FROM messages
WHERE session_id = ?
ORDER BY created_at ASC

@@ -122,6 +127,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) (
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.FinishedAt,
		&i.Provider,
	); err != nil {
		return nil, err
	}
@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
-- Add provider column to messages table
ALTER TABLE messages ADD COLUMN provider TEXT;
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
-- Remove provider column from messages table
ALTER TABLE messages DROP COLUMN provider;
-- +goose StatementEnd
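This is a standard goose migration pair (note that SQLite only supports ALTER TABLE ... DROP COLUMN from version 3.35 on). With pressly/goose, which this package already imports, it would be applied roughly as below — the embedded directory path is illustrative, not taken from this commit:

	// Sketch: applying migrations with goose; the FS path is hypothetical.
	//go:embed migrations/*.sql
	var migrationsFS embed.FS

	func migrate(db *sql.DB) error {
		goose.SetBaseFS(migrationsFS)
		if err := goose.SetDialect("sqlite3"); err != nil {
			return err
		}
		return goose.Up(db, "migrations")
	}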
@@ -27,6 +27,7 @@ type Message struct {
	CreatedAt  int64          `json:"created_at"`
	UpdatedAt  int64          `json:"updated_at"`
	FinishedAt sql.NullInt64  `json:"finished_at"`
	Provider   sql.NullString `json:"provider"`
}

type Session struct {
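Since rows written before this migration predate the column, provider is nullable and surfaces as sql.NullString, so callers have to unwrap it. A sketch:

	// Sketch: reading the nullable provider column.
	msg, err := q.GetMessage(ctx, id)
	if err != nil {
		return err
	}
	providerID := "" // rows created before this migration carry no provider
	if msg.Provider.Valid {
		providerID = msg.Provider.String
	}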
@@ -16,10 +16,11 @@ INSERT INTO messages (
	role,
	parts,
	model,
	provider,
	created_at,
	updated_at
) VALUES (
	?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
	?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now')
)
RETURNING *;
@@ -10,7 +10,7 @@ import (
	"github.com/charmbracelet/crush/internal/fur/provider"
)

const defaultURL = "http://localhost:8080"
const defaultURL = "https://fur.charmcli.dev"

// Client represents a client for the fur service.
type Client struct {
@@ -6,14 +6,13 @@ type Type string

// All the supported AI provider types.
const (
	TypeOpenAI     Type = "openai"
	TypeAnthropic  Type = "anthropic"
	TypeGemini     Type = "gemini"
	TypeAzure      Type = "azure"
	TypeBedrock    Type = "bedrock"
	TypeVertexAI   Type = "vertexai"
	TypeXAI        Type = "xai"
	TypeOpenRouter Type = "openrouter"
	TypeOpenAI    Type = "openai"
	TypeAnthropic Type = "anthropic"
	TypeGemini    Type = "gemini"
	TypeAzure     Type = "azure"
	TypeBedrock   Type = "bedrock"
	TypeVertexAI  Type = "vertexai"
	TypeXAI       Type = "xai"
)

// InferenceProvider represents the inference provider identifier.
@@ -5,17 +5,15 @@ import (
	"encoding/json"
	"fmt"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/session"
)

type agentTool struct {
	sessions   session.Service
	messages   message.Service
	lspClients map[string]*lsp.Client
	agent    Service
	sessions session.Service
	messages message.Service
}

const (

@@ -58,17 +56,12 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
		return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
	}

	agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
	if err != nil {
		return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
	}

	session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
	if err != nil {
		return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
	}

	done, err := agent.Run(ctx, session.ID, params.Prompt)
	done, err := b.agent.Run(ctx, session.ID, params.Prompt)
	if err != nil {
		return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
	}

@@ -101,13 +94,13 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
}

func NewAgentTool(
	Sessions session.Service,
	Messages message.Service,
	LspClients map[string]*lsp.Client,
	agent Service,
	sessions session.Service,
	messages message.Service,
) tools.BaseTool {
	return &agentTool{
		sessions:   Sessions,
		messages:   Messages,
		lspClients: LspClients,
		sessions: sessions,
		messages: messages,
		agent:    agent,
	}
}
@@ -4,16 +4,18 @@ import (
	"context"
	"errors"
	"fmt"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	configv2 "github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/llm/prompt"
	"github.com/charmbracelet/crush/internal/llm/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
@@ -47,71 +49,198 @@ type AgentEvent struct {

type Service interface {
	pubsub.Suscriber[AgentEvent]
	Model() models.Model
	Model() configv2.Model
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
	Update(model configv2.PreferredModel) (configv2.Model, error)
	Summarize(ctx context.Context, sessionID string) error
}

type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg configv2.Agent
	sessions session.Service
	messages message.Service

	tools    []tools.BaseTool
	provider provider.Provider
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	titleProvider     provider.Provider
	summarizeProvider provider.Provider
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	activeRequests sync.Map
}

var agentPromptMap = map[configv2.AgentID]prompt.PromptID{
	configv2.AgentCoder: prompt.PromptCoder,
	configv2.AgentTask:  prompt.PromptTask,
}

func NewAgent(
	agentName config.AgentName,
	agentCfg configv2.Agent,
	// These services are needed in the tools
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	agentTools []tools.BaseTool,
	history history.Service,
	lspClients map[string]*lsp.Client,
) (Service, error) {
	agentProvider, err := createAgentProvider(agentName)
	ctx := context.Background()
	cfg := configv2.Get()
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}

	allTools := []tools.BaseTool{
		tools.NewBashTool(permissions),
		tools.NewEditTool(lspClients, permissions, history),
		tools.NewFetchTool(permissions),
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(lspClients),
		tools.NewWriteTool(lspClients, permissions, history),
	}

	if agentCfg.ID == configv2.AgentCoder {
		taskAgentCfg := configv2.Get().Agents[configv2.AgentTask]
		if taskAgentCfg.ID == "" {
			return nil, fmt.Errorf("task agent not found in config")
		}
		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
		if err != nil {
			return nil, fmt.Errorf("failed to create task agent: %w", err)
		}

		allTools = append(
			allTools,
			NewAgentTool(
				taskAgent,
				sessions,
				messages,
			),
		)
	}

	allTools = append(allTools, otherTools...)
	var providerCfg configv2.ProviderConfig
	for _, p := range cfg.Providers {
		if p.ID == agentCfg.Provider {
			providerCfg = p
			break
		}
	}
	if providerCfg.ID == "" {
		return nil, fmt.Errorf("provider %s not found in config", agentCfg.Provider)
	}

	var model configv2.Model
	for _, m := range providerCfg.Models {
		if m.ID == agentCfg.Model {
			model = m
			break
		}
	}
	if model.ID == "" {
		return nil, fmt.Errorf("model %s not found in provider %s", agentCfg.Model, agentCfg.Provider)
	}

	promptID := agentPromptMap[agentCfg.ID]
	if promptID == "" {
		promptID = prompt.PromptDefault
	}
	opts := []provider.ProviderClientOption{
		provider.WithModel(model),
		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
		provider.WithMaxTokens(model.DefaultMaxTokens),
	}
	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
	if err != nil {
		return nil, err
	}
	var titleProvider provider.Provider
	// Only generate titles for the coder agent
	if agentName == config.AgentCoder {
		titleProvider, err = createAgentProvider(config.AgentTitle)
		if err != nil {
			return nil, err

	smallModelCfg := cfg.Models.Small
	var smallModel configv2.Model

	var smallModelProviderCfg configv2.ProviderConfig
	if smallModelCfg.Provider == providerCfg.ID {
		smallModelProviderCfg = providerCfg
	} else {
		for _, p := range cfg.Providers {
			if p.ID == smallModelCfg.Provider {
				smallModelProviderCfg = p
				break
			}
		}
		if smallModelProviderCfg.ID == "" {
			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
		}
	}
	var summarizeProvider provider.Provider
	if agentName == config.AgentCoder {
		summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
		if err != nil {
			return nil, err
	for _, m := range smallModelProviderCfg.Models {
		if m.ID == smallModelCfg.ModelID {
			smallModel = m
			break
		}
	}
	if smallModel.ID == "" {
		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
	}

	titleOpts := []provider.ProviderClientOption{
		provider.WithModel(smallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
		provider.WithMaxTokens(40),
	}
	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
	if err != nil {
		return nil, err
	}
	summarizeOpts := []provider.ProviderClientOption{
		provider.WithModel(smallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
	}
	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
	if err != nil {
		return nil, err
	}

	agentTools := []tools.BaseTool{}
	if agentCfg.AllowedTools == nil {
		agentTools = allTools
	} else {
		for _, tool := range allTools {
			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
				agentTools = append(agentTools, tool)
			}
		}
	}

	agent := &agent{
		Broker:            pubsub.NewBroker[AgentEvent](),
		provider:          agentProvider,
		messages:          messages,
		sessions:          sessions,
		tools:             agentTools,
		titleProvider:     titleProvider,
		summarizeProvider: summarizeProvider,
		activeRequests:    sync.Map{},
		Broker:              pubsub.NewBroker[AgentEvent](),
		agentCfg:            agentCfg,
		provider:            agentProvider,
		providerID:          string(providerCfg.ID),
		messages:            messages,
		sessions:            sessions,
		tools:               agentTools,
		titleProvider:       titleProvider,
		summarizeProvider:   summarizeProvider,
		summarizeProviderID: string(smallModelProviderCfg.ID),
		activeRequests:      sync.Map{},
	}

	return agent, nil
}

func (a *agent) Model() models.Model {
func (a *agent) Model() configv2.Model {
	return a.provider.Model()
}
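The AllowedTools branch above is what turns the task agent into a read-only one. A sketch of an agent config that opts into the same restricted set — the provider and model values are illustrative; the tool IDs mirror the defaults in the deleted config file:

	// Sketch: a read-only agent definition; a nil AllowedTools would mean "all tools".
	taskCfg := configv2.Agent{
		ID:           configv2.AgentTask,
		Name:         "Task",
		Provider:     provider.InferenceProviderAnthropic, // illustrative
		Model:        "claude-4-sonnet",                   // illustrative
		AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"},
	}
	taskAgent, err := NewAgent(taskCfg, permissions, sessions, messages, history, lspClients)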
@@ -207,7 +336,7 @@ func (a *agent) err(err error) AgentEvent {
}

func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
	if !a.provider.Model().SupportsAttachments && attachments != nil {
	if !a.provider.Model().SupportsImages && attachments != nil {
		attachments = nil
	}
	events := make(chan AgentEvent)

@@ -327,9 +456,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)

	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:  message.Assistant,
		Parts: []message.ContentPart{},
		Model: a.provider.Model().ID,
		Role:     message.Assistant,
		Parts:    []message.ContentPart{},
		Model:    a.provider.Model().ID,
		Provider: a.providerID,
	})
	if err != nil {
		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)

@@ -424,8 +554,9 @@ out:
		parts = append(parts, tr)
	}
	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
		Role:  message.Tool,
		Parts: parts,
		Role:     message.Tool,
		Parts:    parts,
		Provider: a.providerID,
	})
	if err != nil {
		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)

@@ -484,7 +615,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
	return nil
}

func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model configv2.Model, usage provider.TokenUsage) error {
	sess, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
@@ -506,21 +637,48 @@ func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
	return nil
}

func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
func (a *agent) Update(modelCfg configv2.PreferredModel) (configv2.Model, error) {
	if a.IsBusy() {
		return models.Model{}, fmt.Errorf("cannot change model while processing requests")
		return configv2.Model{}, fmt.Errorf("cannot change model while processing requests")
	}

	if err := config.UpdateAgentModel(agentName, modelID); err != nil {
		return models.Model{}, fmt.Errorf("failed to update config: %w", err)
	cfg := configv2.Get()
	var providerCfg configv2.ProviderConfig
	for _, p := range cfg.Providers {
		if p.ID == modelCfg.Provider {
			providerCfg = p
			break
		}
	}
	if providerCfg.ID == "" {
		return configv2.Model{}, fmt.Errorf("provider %s not found in config", modelCfg.Provider)
	}

	provider, err := createAgentProvider(agentName)
	var model configv2.Model
	for _, m := range providerCfg.Models {
		if m.ID == modelCfg.ModelID {
			model = m
			break
		}
	}
	if model.ID == "" {
		return configv2.Model{}, fmt.Errorf("model %s not found in provider %s", modelCfg.ModelID, modelCfg.Provider)
	}

	promptID := agentPromptMap[a.agentCfg.ID]
	if promptID == "" {
		promptID = prompt.PromptDefault
	}
	opts := []provider.ProviderClientOption{
		provider.WithModel(model),
		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
		provider.WithMaxTokens(model.DefaultMaxTokens),
	}
	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
	if err != nil {
		return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
		return configv2.Model{}, err
	}

	a.provider = provider
	a.provider = agentProvider

	return a.provider.Model(), nil
}
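Callers of the new Update pass a provider/model pair rather than an agent name plus model ID. A sketch, assuming PreferredModel carries the Provider and ModelID fields used above:

	// Sketch: switching the running agent to another configured model.
	model, err := coderAgent.Update(configv2.PreferredModel{
		Provider: provider.InferenceProviderOpenAI, // must exist in cfg.Providers
		ModelID:  "gpt-4o",                         // must exist in that provider's Models
	})
	if err != nil {
		return err
	}
	logging.Info("switched model", "model", model.Name)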
@@ -654,7 +812,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
			Time: time.Now().Unix(),
		},
	},
	Model: a.summarizeProvider.Model().ID,
	Model:    a.summarizeProvider.Model().ID,
	Provider: a.summarizeProviderID,
})
if err != nil {
	event = AgentEvent{
@@ -705,51 +864,3 @@ func (a *agent) CancelAll() {
		return true
	})
}

func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
	cfg := config.Get()
	agentConfig, ok := cfg.Agents[agentName]
	if !ok {
		return nil, fmt.Errorf("agent %s not found", agentName)
	}
	model, ok := models.SupportedModels[agentConfig.Model]
	if !ok {
		return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
	}

	providerCfg, ok := cfg.Providers[model.Provider]
	if !ok {
		return nil, fmt.Errorf("provider %s not supported", model.Provider)
	}
	if providerCfg.Disabled {
		return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
	}
	maxTokens := model.DefaultMaxTokens
	if agentConfig.MaxTokens > 0 {
		maxTokens = agentConfig.MaxTokens
	}
	opts := []provider.ProviderClientOption{
		provider.WithAPIKey(providerCfg.APIKey),
		provider.WithModel(model),
		provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
		provider.WithMaxTokens(maxTokens),
	}
	// TODO: reimplement
	// if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
	// 	opts = append(
	// 		opts,
	// 		provider.WithOpenAIOptions(
	// 			provider.WithReasoningEffort(agentConfig.ReasoningEffort),
	// 		),
	// 	)
	// }
	agentProvider, err := provider.NewProvider(
		model.Provider,
		opts...,
	)
	if err != nil {
		return nil, fmt.Errorf("could not create provider: %v", err)
	}

	return agentProvider, nil
}
@@ -18,7 +18,7 @@ import (
type mcpTool struct {
	mcpName     string
	tool        mcp.Tool
	mcpConfig   config.MCPServer
	mcpConfig   config.MCP
	permissions permission.Service
}

@@ -128,7 +128,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
	return tools.NewTextErrorResponse("invalid mcp type"), nil
}

func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
	return &mcpTool{
		mcpName: name,
		tool:    tool,

@@ -139,7 +139,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC

var mcpTools []tools.BaseTool

func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
	var stdioTools []tools.BaseTool
	initRequest := mcp.InitializeRequest{}
	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION

@@ -170,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
	if len(mcpTools) > 0 {
		return mcpTools
	}
	for name, m := range config.Get().MCPServers {
	for name, m := range config.Get().MCP {
		switch m.Type {
		case config.MCPStdio:
			c, err := client.NewStdioMCPClient(
@@ -1,50 +0,0 @@
package agent

import (
	"context"

	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
)

func CoderAgentTools(
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	history history.Service,
	lspClients map[string]*lsp.Client,
) []tools.BaseTool {
	ctx := context.Background()
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}
	return append(
		[]tools.BaseTool{
			tools.NewBashTool(permissions),
			tools.NewEditTool(lspClients, permissions, history),
			tools.NewFetchTool(permissions),
			tools.NewGlobTool(),
			tools.NewGrepTool(),
			tools.NewLsTool(),
			tools.NewSourcegraphTool(),
			tools.NewViewTool(lspClients),
			tools.NewWriteTool(lspClients, permissions, history),
			NewAgentTool(sessions, messages, lspClients),
		}, otherTools...,
	)
}

func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
	return []tools.BaseTool{
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(lspClients),
	}
}
@@ -1,111 +0,0 @@
package models

const (
	ProviderAnthropic InferenceProvider = "anthropic"

	// Models
	Claude35Sonnet ModelID = "claude-3.5-sonnet"
	Claude3Haiku   ModelID = "claude-3-haiku"
	Claude37Sonnet ModelID = "claude-3.7-sonnet"
	Claude35Haiku  ModelID = "claude-3.5-haiku"
	Claude3Opus    ModelID = "claude-3-opus"
	Claude4Opus    ModelID = "claude-4-opus"
	Claude4Sonnet  ModelID = "claude-4-sonnet"
)

// https://docs.anthropic.com/en/docs/about-claude/models/all-models
var AnthropicModels = map[ModelID]Model{
	Claude35Sonnet: {
		ID:                  Claude35Sonnet,
		Name:                "Claude 3.5 Sonnet",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-5-sonnet-latest",
		CostPer1MIn:         3.0,
		CostPer1MInCached:   3.75,
		CostPer1MOutCached:  0.30,
		CostPer1MOut:        15.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    5000,
		SupportsAttachments: true,
	},
	Claude3Haiku: {
		ID:                  Claude3Haiku,
		Name:                "Claude 3 Haiku",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-haiku-20240307", // doesn't support "-latest"
		CostPer1MIn:         0.25,
		CostPer1MInCached:   0.30,
		CostPer1MOutCached:  0.03,
		CostPer1MOut:        1.25,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	Claude37Sonnet: {
		ID:                  Claude37Sonnet,
		Name:                "Claude 3.7 Sonnet",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-7-sonnet-latest",
		CostPer1MIn:         3.0,
		CostPer1MInCached:   3.75,
		CostPer1MOutCached:  0.30,
		CostPer1MOut:        15.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	Claude35Haiku: {
		ID:                  Claude35Haiku,
		Name:                "Claude 3.5 Haiku",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-5-haiku-latest",
		CostPer1MIn:         0.80,
		CostPer1MInCached:   1.0,
		CostPer1MOutCached:  0.08,
		CostPer1MOut:        4.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	Claude3Opus: {
		ID:                  Claude3Opus,
		Name:                "Claude 3 Opus",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-opus-latest",
		CostPer1MIn:         15.0,
		CostPer1MInCached:   18.75,
		CostPer1MOutCached:  1.50,
		CostPer1MOut:        75.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	Claude4Sonnet: {
		ID:                  Claude4Sonnet,
		Name:                "Claude 4 Sonnet",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-sonnet-4-20250514",
		CostPer1MIn:         3.0,
		CostPer1MInCached:   3.75,
		CostPer1MOutCached:  0.30,
		CostPer1MOut:        15.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	Claude4Opus: {
		ID:                  Claude4Opus,
		Name:                "Claude 4 Opus",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-opus-4-20250514",
		CostPer1MIn:         15.0,
		CostPer1MInCached:   18.75,
		CostPer1MOutCached:  1.50,
		CostPer1MOut:        75.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
}
@@ -1,168 +0,0 @@
package models

const ProviderAzure InferenceProvider = "azure"

const (
	AzureGPT41        ModelID = "azure.gpt-4.1"
	AzureGPT41Mini    ModelID = "azure.gpt-4.1-mini"
	AzureGPT41Nano    ModelID = "azure.gpt-4.1-nano"
	AzureGPT45Preview ModelID = "azure.gpt-4.5-preview"
	AzureGPT4o        ModelID = "azure.gpt-4o"
	AzureGPT4oMini    ModelID = "azure.gpt-4o-mini"
	AzureO1           ModelID = "azure.o1"
	AzureO1Mini       ModelID = "azure.o1-mini"
	AzureO3           ModelID = "azure.o3"
	AzureO3Mini       ModelID = "azure.o3-mini"
	AzureO4Mini       ModelID = "azure.o4-mini"
)

var AzureModels = map[ModelID]Model{
	AzureGPT41: {
		ID:                  AzureGPT41,
		Name:                "Azure OpenAI – GPT 4.1",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4.1",
		CostPer1MIn:         OpenAIModels[GPT41].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT41].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT41].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT41].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT41].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT41].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureGPT41Mini: {
		ID:                  AzureGPT41Mini,
		Name:                "Azure OpenAI – GPT 4.1 mini",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4.1-mini",
		CostPer1MIn:         OpenAIModels[GPT41Mini].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT41Mini].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT41Mini].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT41Mini].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT41Mini].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT41Mini].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureGPT41Nano: {
		ID:                  AzureGPT41Nano,
		Name:                "Azure OpenAI – GPT 4.1 nano",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4.1-nano",
		CostPer1MIn:         OpenAIModels[GPT41Nano].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT41Nano].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT41Nano].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT41Nano].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT41Nano].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT41Nano].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureGPT45Preview: {
		ID:                  AzureGPT45Preview,
		Name:                "Azure OpenAI – GPT 4.5 preview",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4.5-preview",
		CostPer1MIn:         OpenAIModels[GPT45Preview].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT45Preview].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT45Preview].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT45Preview].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT45Preview].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT45Preview].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureGPT4o: {
		ID:                  AzureGPT4o,
		Name:                "Azure OpenAI – GPT-4o",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4o",
		CostPer1MIn:         OpenAIModels[GPT4o].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT4o].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT4o].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT4o].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT4o].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT4o].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureGPT4oMini: {
		ID:                  AzureGPT4oMini,
		Name:                "Azure OpenAI – GPT-4o mini",
		Provider:            ProviderAzure,
		APIModel:            "gpt-4o-mini",
		CostPer1MIn:         OpenAIModels[GPT4oMini].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[GPT4oMini].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[GPT4oMini].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[GPT4oMini].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[GPT4oMini].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[GPT4oMini].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	AzureO1: {
		ID:                  AzureO1,
		Name:                "Azure OpenAI – O1",
		Provider:            ProviderAzure,
		APIModel:            "o1",
		CostPer1MIn:         OpenAIModels[O1].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[O1].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[O1].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[O1].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[O1].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[O1].DefaultMaxTokens,
		CanReason:           OpenAIModels[O1].CanReason,
		SupportsAttachments: true,
	},
	AzureO1Mini: {
		ID:                  AzureO1Mini,
		Name:                "Azure OpenAI – O1 mini",
		Provider:            ProviderAzure,
		APIModel:            "o1-mini",
		CostPer1MIn:         OpenAIModels[O1Mini].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[O1Mini].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[O1Mini].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[O1Mini].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[O1Mini].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[O1Mini].DefaultMaxTokens,
		CanReason:           OpenAIModels[O1Mini].CanReason,
		SupportsAttachments: true,
	},
	AzureO3: {
		ID:                  AzureO3,
		Name:                "Azure OpenAI – O3",
		Provider:            ProviderAzure,
		APIModel:            "o3",
		CostPer1MIn:         OpenAIModels[O3].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[O3].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[O3].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[O3].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[O3].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[O3].DefaultMaxTokens,
		CanReason:           OpenAIModels[O3].CanReason,
		SupportsAttachments: true,
	},
	AzureO3Mini: {
		ID:                  AzureO3Mini,
		Name:                "Azure OpenAI – O3 mini",
		Provider:            ProviderAzure,
		APIModel:            "o3-mini",
		CostPer1MIn:         OpenAIModels[O3Mini].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[O3Mini].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[O3Mini].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[O3Mini].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[O3Mini].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[O3Mini].DefaultMaxTokens,
		CanReason:           OpenAIModels[O3Mini].CanReason,
		SupportsAttachments: false,
	},
	AzureO4Mini: {
		ID:                  AzureO4Mini,
		Name:                "Azure OpenAI – O4 mini",
		Provider:            ProviderAzure,
		APIModel:            "o4-mini",
		CostPer1MIn:         OpenAIModels[O4Mini].CostPer1MIn,
		CostPer1MInCached:   OpenAIModels[O4Mini].CostPer1MInCached,
		CostPer1MOut:        OpenAIModels[O4Mini].CostPer1MOut,
		CostPer1MOutCached:  OpenAIModels[O4Mini].CostPer1MOutCached,
		ContextWindow:       OpenAIModels[O4Mini].ContextWindow,
		DefaultMaxTokens:    OpenAIModels[O4Mini].DefaultMaxTokens,
		CanReason:           OpenAIModels[O4Mini].CanReason,
		SupportsAttachments: true,
	},
}
@@ -1,67 +0,0 @@
package models

const (
	ProviderGemini InferenceProvider = "gemini"

	// Models
	Gemini25Flash     ModelID = "gemini-2.5-flash"
	Gemini25          ModelID = "gemini-2.5"
	Gemini20Flash     ModelID = "gemini-2.0-flash"
	Gemini20FlashLite ModelID = "gemini-2.0-flash-lite"
)

var GeminiModels = map[ModelID]Model{
	Gemini25Flash: {
		ID:                  Gemini25Flash,
		Name:                "Gemini 2.5 Flash",
		Provider:            ProviderGemini,
		APIModel:            "gemini-2.5-flash-preview-04-17",
		CostPer1MIn:         0.15,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.60,
		ContextWindow:       1000000,
		DefaultMaxTokens:    50000,
		SupportsAttachments: true,
	},
	Gemini25: {
		ID:                  Gemini25,
		Name:                "Gemini 2.5 Pro",
		Provider:            ProviderGemini,
		APIModel:            "gemini-2.5-pro-preview-05-06",
		CostPer1MIn:         1.25,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        10,
		ContextWindow:       1000000,
		DefaultMaxTokens:    50000,
		SupportsAttachments: true,
	},

	Gemini20Flash: {
		ID:                  Gemini20Flash,
		Name:                "Gemini 2.0 Flash",
		Provider:            ProviderGemini,
		APIModel:            "gemini-2.0-flash",
		CostPer1MIn:         0.10,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.40,
		ContextWindow:       1000000,
		DefaultMaxTokens:    6000,
		SupportsAttachments: true,
	},
	Gemini20FlashLite: {
		ID:                  Gemini20FlashLite,
		Name:                "Gemini 2.0 Flash Lite",
		Provider:            ProviderGemini,
		APIModel:            "gemini-2.0-flash-lite",
		CostPer1MIn:         0.05,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.30,
		ContextWindow:       1000000,
		DefaultMaxTokens:    6000,
		SupportsAttachments: true,
	},
}
@@ -1,87 +0,0 @@
package models

const (
	ProviderGROQ InferenceProvider = "groq"

	// GROQ
	QWENQwq ModelID = "qwen-qwq"

	// GROQ preview models
	Llama4Scout               ModelID = "meta-llama/llama-4-scout-17b-16e-instruct"
	Llama4Maverick            ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct"
	Llama3_3_70BVersatile     ModelID = "llama-3.3-70b-versatile"
	DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b"
)

var GroqModels = map[ModelID]Model{
	//
	// GROQ
	QWENQwq: {
		ID:                 QWENQwq,
		Name:               "Qwen Qwq",
		Provider:           ProviderGROQ,
		APIModel:           "qwen-qwq-32b",
		CostPer1MIn:        0.29,
		CostPer1MInCached:  0.275,
		CostPer1MOutCached: 0.0,
		CostPer1MOut:       0.39,
		ContextWindow:      128_000,
		DefaultMaxTokens:   50000,
		// for some reason, the groq api doesn't like the reasoningEffort parameter
		CanReason:           false,
		SupportsAttachments: false,
	},

	Llama4Scout: {
		ID:                  Llama4Scout,
		Name:                "Llama4Scout",
		Provider:            ProviderGROQ,
		APIModel:            "meta-llama/llama-4-scout-17b-16e-instruct",
		CostPer1MIn:         0.11,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.34,
		ContextWindow:       128_000, // 10M when?
		SupportsAttachments: true,
	},

	Llama4Maverick: {
		ID:                  Llama4Maverick,
		Name:                "Llama4Maverick",
		Provider:            ProviderGROQ,
		APIModel:            "meta-llama/llama-4-maverick-17b-128e-instruct",
		CostPer1MIn:         0.20,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.20,
		ContextWindow:       128_000,
		SupportsAttachments: true,
	},

	Llama3_3_70BVersatile: {
		ID:                  Llama3_3_70BVersatile,
		Name:                "Llama3_3_70BVersatile",
		Provider:            ProviderGROQ,
		APIModel:            "llama-3.3-70b-versatile",
		CostPer1MIn:         0.59,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.79,
		ContextWindow:       128_000,
		SupportsAttachments: false,
	},

	DeepseekR1DistillLlama70b: {
		ID:                  DeepseekR1DistillLlama70b,
		Name:                "DeepseekR1DistillLlama70b",
		Provider:            ProviderGROQ,
		APIModel:            "deepseek-r1-distill-llama-70b",
		CostPer1MIn:         0.75,
		CostPer1MInCached:   0,
		CostPer1MOutCached:  0,
		CostPer1MOut:        0.99,
		ContextWindow:       128_000,
		CanReason:           true,
		SupportsAttachments: false,
	},
}
@@ -1,206 +0,0 @@
package models

import (
	"cmp"
	"context"
	"encoding/json"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strings"
	"unicode"

	"github.com/charmbracelet/crush/internal/logging"
	"github.com/spf13/viper"
)

const (
	ProviderLocal InferenceProvider = "local"

	localModelsPath        = "v1/models"
	lmStudioBetaModelsPath = "api/v0/models"
)

func init() {
	if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
		localEndpoint, err := url.Parse(endpoint)
		if err != nil {
			logging.Debug("Failed to parse local endpoint",
				"error", err,
				"endpoint", endpoint,
			)
			return
		}

		load := func(url *url.URL, path string) []localModel {
			url.Path = path
			return listLocalModels(url.String())
		}

		models := load(localEndpoint, lmStudioBetaModelsPath)

		if len(models) == 0 {
			models = load(localEndpoint, localModelsPath)
		}

		if len(models) == 0 {
			logging.Debug("No local models found",
				"endpoint", endpoint,
			)
			return
		}

		loadLocalModels(models)

		viper.SetDefault("providers.local.apiKey", "dummy")
	}
}

type localModelList struct {
	Data []localModel `json:"data"`
}

type localModel struct {
	ID                  string `json:"id"`
	Object              string `json:"object"`
	Type                string `json:"type"`
	Publisher           string `json:"publisher"`
	Arch                string `json:"arch"`
	CompatibilityType   string `json:"compatibility_type"`
	Quantization        string `json:"quantization"`
	State               string `json:"state"`
	MaxContextLength    int64  `json:"max_context_length"`
	LoadedContextLength int64  `json:"loaded_context_length"`
}
func listLocalModels(modelsEndpoint string) []localModel {
	// Build and execute the request; the original snippet used the
	// *http.Request returned by NewRequestWithContext as if it were a
	// response, which would panic at runtime.
	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, modelsEndpoint, nil)
	if err != nil {
		logging.Debug("Failed to list local models",
			"error", err,
			"endpoint", modelsEndpoint,
		)
		return nil
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		logging.Debug("Failed to list local models",
			"error", err,
			"endpoint", modelsEndpoint,
		)
		return nil
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		logging.Debug("Failed to list local models",
			"status", res.Status,
			"endpoint", modelsEndpoint,
		)
		return nil
	}

	var modelList localModelList
	if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
		logging.Debug("Failed to list local models",
			"error", err,
			"endpoint", modelsEndpoint,
		)
		return nil
	}

	var supportedModels []localModel
	for _, model := range modelList.Data {
		if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
			if model.Object != "model" || model.Type != "llm" {
				logging.Debug("Skipping unsupported LMStudio model",
					"endpoint", modelsEndpoint,
					"id", model.ID,
					"object", model.Object,
					"type", model.Type,
				)

				continue
			}
		}

		supportedModels = append(supportedModels, model)
	}

	return supportedModels
}
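For orientation, the listing decoded above is OpenAI-style JSON. A minimal sketch of a payload that listLocalModels would accept, written as a Go raw string; the field values are illustrative, not taken from a real LM Studio response:

// Illustrative payload for the decoder above; values are hypothetical.
const exampleLocalListing = `{
  "data": [{
    "id": "qwen2.5-7b-instruct",
    "object": "model",
    "type": "llm",
    "state": "loaded",
    "max_context_length": 32768,
    "loaded_context_length": 8192
  }]
}`

Only entries with object "model" and type "llm" survive the LM Studio beta-endpoint filter.
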
func loadLocalModels(models []localModel) {
	for i, m := range models {
		model := convertLocalModel(m)
		SupportedModels[model.ID] = model

		if i == 0 || m.State == "loaded" {
			viper.SetDefault("agents.coder.model", model.ID)
			viper.SetDefault("agents.summarizer.model", model.ID)
			viper.SetDefault("agents.task.model", model.ID)
			viper.SetDefault("agents.title.model", model.ID)
		}
	}
}

func convertLocalModel(model localModel) Model {
	return Model{
		ID:                  ModelID("local." + model.ID),
		Name:                friendlyModelName(model.ID),
		Provider:            ProviderLocal,
		APIModel:            model.ID,
		ContextWindow:       cmp.Or(model.LoadedContextLength, 4096),
		DefaultMaxTokens:    cmp.Or(model.LoadedContextLength, 4096),
		CanReason:           true,
		SupportsAttachments: true,
	}
}

var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)

func friendlyModelName(modelID string) string {
	mainID := modelID
	tag := ""

	if slash := strings.LastIndex(mainID, "/"); slash != -1 {
		mainID = mainID[slash+1:]
	}

	if at := strings.Index(modelID, "@"); at != -1 {
		mainID = modelID[:at]
		tag = modelID[at+1:]
	}

	match := modelInfoRegex.FindStringSubmatch(mainID)
	if match == nil {
		return modelID
	}

	capitalize := func(s string) string {
		if s == "" {
			return ""
		}
		runes := []rune(s)
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}

	family := capitalize(match[1])
	version := ""
	label := ""

	if len(match) > 2 && match[2] != "" {
		version = strings.ToUpper(match[2])
	}

	if len(match) > 3 && match[3] != "" {
		label = capitalize(match[3])
	}

	var parts []string
	if family != "" {
		parts = append(parts, family)
	}
	if version != "" {
		parts = append(parts, version)
	}
	if label != "" {
		parts = append(parts, label)
	}
	if tag != "" {
		parts = append(parts, tag)
	}

	return strings.Join(parts, " ")
}
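Hand-traced through modelInfoRegex above, the prettifier splits an ID into family, version, and label. These calls are illustrative and not part of the change set:

// Hand-traced expectations for friendlyModelName; illustrative only.
_ = friendlyModelName("deepseek-r1-distill-llama-70b") // "Deepseek R1 Distill"
_ = friendlyModelName("llama-3.3-70b-versatile")       // "Llama 3.3"
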
@@ -1,74 +0,0 @@
package models

import "maps"

type (
	ModelID           string
	InferenceProvider string
)

type Model struct {
	ID                  ModelID           `json:"id"`
	Name                string            `json:"name"`
	Provider            InferenceProvider `json:"provider"`
	APIModel            string            `json:"api_model"`
	CostPer1MIn         float64           `json:"cost_per_1m_in"`
	CostPer1MOut        float64           `json:"cost_per_1m_out"`
	CostPer1MInCached   float64           `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached  float64           `json:"cost_per_1m_out_cached"`
	ContextWindow       int64             `json:"context_window"`
	DefaultMaxTokens    int64             `json:"default_max_tokens"`
	CanReason           bool              `json:"can_reason"`
	SupportsAttachments bool              `json:"supports_attachments"`
}

// Model IDs
const ( // GEMINI
	// Bedrock
	BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
)

const (
	ProviderBedrock InferenceProvider = "bedrock"
	// ForTests
	ProviderMock InferenceProvider = "__mock"
)

var SupportedModels = map[ModelID]Model{
	// Bedrock
	BedrockClaude37Sonnet: {
		ID:                 BedrockClaude37Sonnet,
		Name:               "Bedrock: Claude 3.7 Sonnet",
		Provider:           ProviderBedrock,
		APIModel:           "anthropic.claude-3-7-sonnet-20250219-v1:0",
		CostPer1MIn:        3.0,
		CostPer1MInCached:  3.75,
		CostPer1MOutCached: 0.30,
		CostPer1MOut:       15.0,
	},
}

var KnownProviders = []InferenceProvider{
	ProviderAnthropic,
	ProviderOpenAI,
	ProviderGemini,
	ProviderAzure,
	ProviderGROQ,
	ProviderLocal,
	ProviderOpenRouter,
	ProviderVertexAI,
	ProviderBedrock,
	ProviderXAI,
	ProviderMock,
}

func init() {
	maps.Copy(SupportedModels, AnthropicModels)
	maps.Copy(SupportedModels, OpenAIModels)
	maps.Copy(SupportedModels, GeminiModels)
	maps.Copy(SupportedModels, GroqModels)
	maps.Copy(SupportedModels, AzureModels)
	maps.Copy(SupportedModels, OpenRouterModels)
	maps.Copy(SupportedModels, XAIModels)
	maps.Copy(SupportedModels, VertexAIGeminiModels)
}
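Because each provider file contributes its map through the init-time maps.Copy calls above, a single lookup resolves any model. A hypothetical caller, shown only for illustration:

// Hypothetical lookup against the aggregated registry.
if m, ok := models.SupportedModels[models.BedrockClaude37Sonnet]; ok {
	fmt.Printf("%s via %s: $%.2f/M in, $%.2f/M out\n", m.Name, m.Provider, m.CostPer1MIn, m.CostPer1MOut)
}
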
@@ -1,181 +0,0 @@
package models

const (
	ProviderOpenAI InferenceProvider = "openai"

	GPT41        ModelID = "gpt-4.1"
	GPT41Mini    ModelID = "gpt-4.1-mini"
	GPT41Nano    ModelID = "gpt-4.1-nano"
	GPT45Preview ModelID = "gpt-4.5-preview"
	GPT4o        ModelID = "gpt-4o"
	GPT4oMini    ModelID = "gpt-4o-mini"
	O1           ModelID = "o1"
	O1Pro        ModelID = "o1-pro"
	O1Mini       ModelID = "o1-mini"
	O3           ModelID = "o3"
	O3Mini       ModelID = "o3-mini"
	O4Mini       ModelID = "o4-mini"
)

var OpenAIModels = map[ModelID]Model{
	GPT41: {
		ID:                  GPT41,
		Name:                "GPT 4.1",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4.1",
		CostPer1MIn:         2.00,
		CostPer1MInCached:   0.50,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        8.00,
		ContextWindow:       1_047_576,
		DefaultMaxTokens:    20000,
		SupportsAttachments: true,
	},
	GPT41Mini: {
		ID:                  GPT41Mini,
		Name:                "GPT 4.1 mini",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4.1-mini", // was "gpt-4.1", which is the full-size model's API name
		CostPer1MIn:         0.40,
		CostPer1MInCached:   0.10,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        1.60,
		ContextWindow:       200_000,
		DefaultMaxTokens:    20000,
		SupportsAttachments: true,
	},
	GPT41Nano: {
		ID:                  GPT41Nano,
		Name:                "GPT 4.1 nano",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4.1-nano",
		CostPer1MIn:         0.10,
		CostPer1MInCached:   0.025,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        0.40,
		ContextWindow:       1_047_576,
		DefaultMaxTokens:    20000,
		SupportsAttachments: true,
	},
	GPT45Preview: {
		ID:                  GPT45Preview,
		Name:                "GPT 4.5 preview",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4.5-preview",
		CostPer1MIn:         75.00,
		CostPer1MInCached:   37.50,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        150.00,
		ContextWindow:       128_000,
		DefaultMaxTokens:    15000,
		SupportsAttachments: true,
	},
	GPT4o: {
		ID:                  GPT4o,
		Name:                "GPT 4o",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4o",
		CostPer1MIn:         2.50,
		CostPer1MInCached:   1.25,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        10.00,
		ContextWindow:       128_000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	GPT4oMini: {
		ID:                  GPT4oMini,
		Name:                "GPT 4o mini",
		Provider:            ProviderOpenAI,
		APIModel:            "gpt-4o-mini",
		CostPer1MIn:         0.15,
		CostPer1MInCached:   0.075,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        0.60,
		ContextWindow:       128_000,
		SupportsAttachments: true,
	},
	O1: {
		ID:                  O1,
		Name:                "O1",
		Provider:            ProviderOpenAI,
		APIModel:            "o1",
		CostPer1MIn:         15.00,
		CostPer1MInCached:   7.50,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        60.00,
		ContextWindow:       200_000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	O1Pro: {
		ID:                  O1Pro,
		Name:                "o1 pro",
		Provider:            ProviderOpenAI,
		APIModel:            "o1-pro",
		CostPer1MIn:         150.00,
		CostPer1MInCached:   0.0,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        600.00,
		ContextWindow:       200_000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	O1Mini: {
		ID:                  O1Mini,
		Name:                "o1 mini",
		Provider:            ProviderOpenAI,
		APIModel:            "o1-mini",
		CostPer1MIn:         1.10,
		CostPer1MInCached:   0.55,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        4.40,
		ContextWindow:       128_000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	O3: {
		ID:                  O3,
		Name:                "o3",
		Provider:            ProviderOpenAI,
		APIModel:            "o3",
		CostPer1MIn:         10.00,
		CostPer1MInCached:   2.50,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        40.00,
		ContextWindow:       200_000,
		CanReason:           true,
		SupportsAttachments: true,
	},
	O3Mini: {
		ID:                  O3Mini,
		Name:                "o3 mini",
		Provider:            ProviderOpenAI,
		APIModel:            "o3-mini",
		CostPer1MIn:         1.10,
		CostPer1MInCached:   0.55,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        4.40,
		ContextWindow:       200_000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: false,
	},
	O4Mini: {
		ID:                  O4Mini,
		Name:                "o4 mini",
		Provider:            ProviderOpenAI,
		APIModel:            "o4-mini",
		CostPer1MIn:         1.10,
		CostPer1MInCached:   0.275,
		CostPer1MOutCached:  0.0,
		CostPer1MOut:        4.40,
		ContextWindow:       128_000,
		DefaultMaxTokens:    50000,
		CanReason:           true,
		SupportsAttachments: true,
	},
}
@@ -1,276 +0,0 @@
package models

const (
	ProviderOpenRouter InferenceProvider = "openrouter"

	OpenRouterGPT41          ModelID = "openrouter.gpt-4.1"
	OpenRouterGPT41Mini      ModelID = "openrouter.gpt-4.1-mini"
	OpenRouterGPT41Nano      ModelID = "openrouter.gpt-4.1-nano"
	OpenRouterGPT45Preview   ModelID = "openrouter.gpt-4.5-preview"
	OpenRouterGPT4o          ModelID = "openrouter.gpt-4o"
	OpenRouterGPT4oMini      ModelID = "openrouter.gpt-4o-mini"
	OpenRouterO1             ModelID = "openrouter.o1"
	OpenRouterO1Pro          ModelID = "openrouter.o1-pro"
	OpenRouterO1Mini         ModelID = "openrouter.o1-mini"
	OpenRouterO3             ModelID = "openrouter.o3"
	OpenRouterO3Mini         ModelID = "openrouter.o3-mini"
	OpenRouterO4Mini         ModelID = "openrouter.o4-mini"
	OpenRouterGemini25Flash  ModelID = "openrouter.gemini-2.5-flash"
	OpenRouterGemini25       ModelID = "openrouter.gemini-2.5"
	OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet"
	OpenRouterClaude3Haiku   ModelID = "openrouter.claude-3-haiku"
	OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
	OpenRouterClaude35Haiku  ModelID = "openrouter.claude-3.5-haiku"
	OpenRouterClaude3Opus    ModelID = "openrouter.claude-3-opus"
	OpenRouterDeepSeekR1Free ModelID = "openrouter.deepseek-r1-free"
)

var OpenRouterModels = map[ModelID]Model{
	OpenRouterGPT41: {
		ID:                 OpenRouterGPT41,
		Name:               "OpenRouter – GPT 4.1",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4.1",
		CostPer1MIn:        OpenAIModels[GPT41].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT41].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT41].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT41].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[GPT41].DefaultMaxTokens,
	},
	OpenRouterGPT41Mini: {
		ID:                 OpenRouterGPT41Mini,
		Name:               "OpenRouter – GPT 4.1 mini",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4.1-mini",
		CostPer1MIn:        OpenAIModels[GPT41Mini].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT41Mini].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT41Mini].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT41Mini].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[GPT41Mini].DefaultMaxTokens,
	},
	OpenRouterGPT41Nano: {
		ID:                 OpenRouterGPT41Nano,
		Name:               "OpenRouter – GPT 4.1 nano",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4.1-nano",
		CostPer1MIn:        OpenAIModels[GPT41Nano].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT41Nano].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT41Nano].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT41Nano].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[GPT41Nano].DefaultMaxTokens,
	},
	OpenRouterGPT45Preview: {
		ID:                 OpenRouterGPT45Preview,
		Name:               "OpenRouter – GPT 4.5 preview",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4.5-preview",
		CostPer1MIn:        OpenAIModels[GPT45Preview].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT45Preview].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT45Preview].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT45Preview].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[GPT45Preview].DefaultMaxTokens,
	},
	OpenRouterGPT4o: {
		ID:                 OpenRouterGPT4o,
		Name:               "OpenRouter – GPT 4o",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4o",
		CostPer1MIn:        OpenAIModels[GPT4o].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT4o].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT4o].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT4o].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[GPT4o].DefaultMaxTokens,
	},
	OpenRouterGPT4oMini: {
		ID:                 OpenRouterGPT4oMini,
		Name:               "OpenRouter – GPT 4o mini",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/gpt-4o-mini",
		CostPer1MIn:        OpenAIModels[GPT4oMini].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[GPT4oMini].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[GPT4oMini].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[GPT4oMini].ContextWindow,
	},
	OpenRouterO1: {
		ID:                 OpenRouterO1,
		Name:               "OpenRouter – O1",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o1",
		CostPer1MIn:        OpenAIModels[O1].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O1].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O1].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O1].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O1].DefaultMaxTokens,
		CanReason:          OpenAIModels[O1].CanReason,
	},
	OpenRouterO1Pro: {
		ID:                 OpenRouterO1Pro,
		Name:               "OpenRouter – o1 pro",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o1-pro",
		CostPer1MIn:        OpenAIModels[O1Pro].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O1Pro].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O1Pro].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O1Pro].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O1Pro].DefaultMaxTokens,
		CanReason:          OpenAIModels[O1Pro].CanReason,
	},
	OpenRouterO1Mini: {
		ID:                 OpenRouterO1Mini,
		Name:               "OpenRouter – o1 mini",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o1-mini",
		CostPer1MIn:        OpenAIModels[O1Mini].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O1Mini].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O1Mini].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O1Mini].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O1Mini].DefaultMaxTokens,
		CanReason:          OpenAIModels[O1Mini].CanReason,
	},
	OpenRouterO3: {
		ID:                 OpenRouterO3,
		Name:               "OpenRouter – o3",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o3",
		CostPer1MIn:        OpenAIModels[O3].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O3].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O3].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O3].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O3].DefaultMaxTokens,
		CanReason:          OpenAIModels[O3].CanReason,
	},
	OpenRouterO3Mini: {
		ID:                 OpenRouterO3Mini,
		Name:               "OpenRouter – o3 mini",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o3-mini-high",
		CostPer1MIn:        OpenAIModels[O3Mini].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O3Mini].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O3Mini].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O3Mini].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O3Mini].DefaultMaxTokens,
		CanReason:          OpenAIModels[O3Mini].CanReason,
	},
	OpenRouterO4Mini: {
		ID:                 OpenRouterO4Mini,
		Name:               "OpenRouter – o4 mini",
		Provider:           ProviderOpenRouter,
		APIModel:           "openai/o4-mini-high",
		CostPer1MIn:        OpenAIModels[O4Mini].CostPer1MIn,
		CostPer1MInCached:  OpenAIModels[O4Mini].CostPer1MInCached,
		CostPer1MOut:       OpenAIModels[O4Mini].CostPer1MOut,
		CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
		ContextWindow:      OpenAIModels[O4Mini].ContextWindow,
		DefaultMaxTokens:   OpenAIModels[O4Mini].DefaultMaxTokens,
		CanReason:          OpenAIModels[O4Mini].CanReason,
	},
	OpenRouterGemini25Flash: {
		ID:                 OpenRouterGemini25Flash,
		Name:               "OpenRouter – Gemini 2.5 Flash",
		Provider:           ProviderOpenRouter,
		APIModel:           "google/gemini-2.5-flash-preview:thinking",
		CostPer1MIn:        GeminiModels[Gemini25Flash].CostPer1MIn,
		CostPer1MInCached:  GeminiModels[Gemini25Flash].CostPer1MInCached,
		CostPer1MOut:       GeminiModels[Gemini25Flash].CostPer1MOut,
		CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
		ContextWindow:      GeminiModels[Gemini25Flash].ContextWindow,
		DefaultMaxTokens:   GeminiModels[Gemini25Flash].DefaultMaxTokens,
	},
	OpenRouterGemini25: {
		ID:                 OpenRouterGemini25,
		Name:               "OpenRouter – Gemini 2.5 Pro",
		Provider:           ProviderOpenRouter,
		APIModel:           "google/gemini-2.5-pro-preview-03-25",
		CostPer1MIn:        GeminiModels[Gemini25].CostPer1MIn,
		CostPer1MInCached:  GeminiModels[Gemini25].CostPer1MInCached,
		CostPer1MOut:       GeminiModels[Gemini25].CostPer1MOut,
		CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
		ContextWindow:      GeminiModels[Gemini25].ContextWindow,
		DefaultMaxTokens:   GeminiModels[Gemini25].DefaultMaxTokens,
	},
	OpenRouterClaude35Sonnet: {
		ID:                 OpenRouterClaude35Sonnet,
		Name:               "OpenRouter – Claude 3.5 Sonnet",
		Provider:           ProviderOpenRouter,
		APIModel:           "anthropic/claude-3.5-sonnet",
		CostPer1MIn:        AnthropicModels[Claude35Sonnet].CostPer1MIn,
		CostPer1MInCached:  AnthropicModels[Claude35Sonnet].CostPer1MInCached,
		CostPer1MOut:       AnthropicModels[Claude35Sonnet].CostPer1MOut,
		CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached,
		ContextWindow:      AnthropicModels[Claude35Sonnet].ContextWindow,
		DefaultMaxTokens:   AnthropicModels[Claude35Sonnet].DefaultMaxTokens,
	},
	OpenRouterClaude3Haiku: {
		ID:                 OpenRouterClaude3Haiku,
		Name:               "OpenRouter – Claude 3 Haiku",
		Provider:           ProviderOpenRouter,
		APIModel:           "anthropic/claude-3-haiku",
		CostPer1MIn:        AnthropicModels[Claude3Haiku].CostPer1MIn,
		CostPer1MInCached:  AnthropicModels[Claude3Haiku].CostPer1MInCached,
		CostPer1MOut:       AnthropicModels[Claude3Haiku].CostPer1MOut,
		CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached,
		ContextWindow:      AnthropicModels[Claude3Haiku].ContextWindow,
		DefaultMaxTokens:   AnthropicModels[Claude3Haiku].DefaultMaxTokens,
	},
	OpenRouterClaude37Sonnet: {
		ID:                 OpenRouterClaude37Sonnet,
		Name:               "OpenRouter – Claude 3.7 Sonnet",
		Provider:           ProviderOpenRouter,
		APIModel:           "anthropic/claude-3.7-sonnet",
		CostPer1MIn:        AnthropicModels[Claude37Sonnet].CostPer1MIn,
		CostPer1MInCached:  AnthropicModels[Claude37Sonnet].CostPer1MInCached,
		CostPer1MOut:       AnthropicModels[Claude37Sonnet].CostPer1MOut,
		CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached,
		ContextWindow:      AnthropicModels[Claude37Sonnet].ContextWindow,
		DefaultMaxTokens:   AnthropicModels[Claude37Sonnet].DefaultMaxTokens,
		CanReason:          AnthropicModels[Claude37Sonnet].CanReason,
	},
	OpenRouterClaude35Haiku: {
		ID:                 OpenRouterClaude35Haiku,
		Name:               "OpenRouter – Claude 3.5 Haiku",
		Provider:           ProviderOpenRouter,
		APIModel:           "anthropic/claude-3.5-haiku",
		CostPer1MIn:        AnthropicModels[Claude35Haiku].CostPer1MIn,
		CostPer1MInCached:  AnthropicModels[Claude35Haiku].CostPer1MInCached,
		CostPer1MOut:       AnthropicModels[Claude35Haiku].CostPer1MOut,
		CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached,
		ContextWindow:      AnthropicModels[Claude35Haiku].ContextWindow,
		DefaultMaxTokens:   AnthropicModels[Claude35Haiku].DefaultMaxTokens,
	},
	OpenRouterClaude3Opus: {
		ID:                 OpenRouterClaude3Opus,
		Name:               "OpenRouter – Claude 3 Opus",
		Provider:           ProviderOpenRouter,
		APIModel:           "anthropic/claude-3-opus",
		CostPer1MIn:        AnthropicModels[Claude3Opus].CostPer1MIn,
		CostPer1MInCached:  AnthropicModels[Claude3Opus].CostPer1MInCached,
		CostPer1MOut:       AnthropicModels[Claude3Opus].CostPer1MOut,
		CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached,
		ContextWindow:      AnthropicModels[Claude3Opus].ContextWindow,
		DefaultMaxTokens:   AnthropicModels[Claude3Opus].DefaultMaxTokens,
	},

	OpenRouterDeepSeekR1Free: {
		ID:                 OpenRouterDeepSeekR1Free,
		Name:               "OpenRouter – DeepSeek R1 Free",
		Provider:           ProviderOpenRouter,
		APIModel:           "deepseek/deepseek-r1-0528:free",
		CostPer1MIn:        0,
		CostPer1MInCached:  0,
		CostPer1MOut:       0,
		CostPer1MOutCached: 0,
		ContextWindow:      163_840,
		DefaultMaxTokens:   10000,
	},
}
@@ -1,38 +0,0 @@
package models

const (
	ProviderVertexAI InferenceProvider = "vertexai"

	// Models
	VertexAIGemini25Flash ModelID = "vertexai.gemini-2.5-flash"
	VertexAIGemini25      ModelID = "vertexai.gemini-2.5"
)

var VertexAIGeminiModels = map[ModelID]Model{
	VertexAIGemini25Flash: {
		ID:                  VertexAIGemini25Flash,
		Name:                "VertexAI: Gemini 2.5 Flash",
		Provider:            ProviderVertexAI,
		APIModel:            "gemini-2.5-flash-preview-04-17",
		CostPer1MIn:         GeminiModels[Gemini25Flash].CostPer1MIn,
		CostPer1MInCached:   GeminiModels[Gemini25Flash].CostPer1MInCached,
		CostPer1MOut:        GeminiModels[Gemini25Flash].CostPer1MOut,
		CostPer1MOutCached:  GeminiModels[Gemini25Flash].CostPer1MOutCached,
		ContextWindow:       GeminiModels[Gemini25Flash].ContextWindow,
		DefaultMaxTokens:    GeminiModels[Gemini25Flash].DefaultMaxTokens,
		SupportsAttachments: true,
	},
	VertexAIGemini25: {
		ID:                  VertexAIGemini25,
		Name:                "VertexAI: Gemini 2.5 Pro",
		Provider:            ProviderVertexAI,
		APIModel:            "gemini-2.5-pro-preview-03-25",
		CostPer1MIn:         GeminiModels[Gemini25].CostPer1MIn,
		CostPer1MInCached:   GeminiModels[Gemini25].CostPer1MInCached,
		CostPer1MOut:        GeminiModels[Gemini25].CostPer1MOut,
		CostPer1MOutCached:  GeminiModels[Gemini25].CostPer1MOutCached,
		ContextWindow:       GeminiModels[Gemini25].ContextWindow,
		DefaultMaxTokens:    GeminiModels[Gemini25].DefaultMaxTokens,
		SupportsAttachments: true,
	},
}
@@ -1,61 +0,0 @@
package models

const (
	ProviderXAI InferenceProvider = "xai"

	XAIGrok3Beta         ModelID = "grok-3-beta"
	XAIGrok3MiniBeta     ModelID = "grok-3-mini-beta"
	XAIGrok3FastBeta     ModelID = "grok-3-fast-beta"
	XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
)

var XAIModels = map[ModelID]Model{
	XAIGrok3Beta: {
		ID:                 XAIGrok3Beta,
		Name:               "Grok3 Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-beta",
		CostPer1MIn:        3.0,
		CostPer1MInCached:  0,
		CostPer1MOut:       15,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAIGrok3MiniBeta: {
		ID:                 XAIGrok3MiniBeta,
		Name:               "Grok3 Mini Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-mini-beta",
		CostPer1MIn:        0.3,
		CostPer1MInCached:  0,
		CostPer1MOut:       0.5,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAIGrok3FastBeta: {
		ID:                 XAIGrok3FastBeta,
		Name:               "Grok3 Fast Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-fast-beta",
		CostPer1MIn:        5,
		CostPer1MInCached:  0,
		CostPer1MOut:       25,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAiGrok3MiniFastBeta: {
		ID:                 XAiGrok3MiniFastBeta,
		Name:               "Grok3 Mini Fast Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-mini-fast-beta",
		CostPer1MIn:        0.6,
		CostPer1MInCached:  0,
		CostPer1MOut:       4.0,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
}
@@ -9,19 +9,27 @@ import (
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
)

func CoderPrompt(provider models.InferenceProvider) string {
func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
	basePrompt := baseAnthropicCoderPrompt
	switch provider {
	case models.ProviderOpenAI:
	switch p {
	case provider.InferenceProviderOpenAI:
		basePrompt = baseOpenAICoderPrompt
	}
	envInfo := getEnvironmentInfo()

	return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
	basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())

	contextContent := getContextFromPaths(contextFiles)
	logging.Debug("Context content", "Context", contextContent)
	if contextContent != "" {
		return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
	}
	return basePrompt
}

const baseOpenAICoderPrompt = `
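The new signature threads the context files through the prompt builder instead of relying on a package-level cache. A sketch of a call site under the new API (the file names are placeholders, not values from this change):

// Hypothetical call site; the context file names are placeholders.
system := CoderPrompt(provider.InferenceProviderOpenAI, "CRUSH.md", ".cursorrules")
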
@@ -1,60 +1,44 @@
package prompt

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/fur/provider"
)

func GetAgentPrompt(agentName config.AgentName, provider models.InferenceProvider) string {
type PromptID string

const (
	PromptCoder      PromptID = "coder"
	PromptTitle      PromptID = "title"
	PromptTask       PromptID = "task"
	PromptSummarizer PromptID = "summarizer"
	PromptDefault    PromptID = "default"
)

func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
	basePrompt := ""
	switch agentName {
	case config.AgentCoder:
	switch promptID {
	case PromptCoder:
		basePrompt = CoderPrompt(provider)
	case config.AgentTitle:
	case PromptTitle:
		basePrompt = TitlePrompt(provider)
	case config.AgentTask:
	case PromptTask:
		basePrompt = TaskPrompt(provider)
	case config.AgentSummarizer:
	case PromptSummarizer:
		basePrompt = SummarizerPrompt(provider)
	default:
		basePrompt = "You are a helpful assistant"
	}

	if agentName == config.AgentCoder || agentName == config.AgentTask {
		// Add context from project-specific instruction files if they exist
		contextContent := getContextFromPaths()
		logging.Debug("Context content", "Context", contextContent)
		if contextContent != "" {
			return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
		}
	}
	return basePrompt
}

var (
	onceContext    sync.Once
	contextContent string
)

func getContextFromPaths() string {
	onceContext.Do(func() {
		var (
			cfg          = config.Get()
			workDir      = cfg.WorkingDir
			contextPaths = cfg.ContextPaths
		)

		contextContent = processContextPaths(workDir, contextPaths)
	})

	return contextContent
func getContextFromPaths(contextPaths []string) string {
	return processContextPaths(config.WorkingDirectory(), contextPaths)
}

func processContextPaths(workDir string, paths []string) string {
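GetPrompt keys prompts by PromptID rather than agent name and takes the context paths explicitly, which is what lets the sync.Once cache above go away. A sketch of the new call shape (hypothetical; the paths are placeholders):

// Hypothetical call under the reworked API; paths are placeholders.
system := prompt.GetPrompt(prompt.PromptCoder, provider.InferenceProviderAnthropic, "CRUSH.md", "docs/CONTEXT.md")
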
@@ -15,16 +15,10 @@ func TestGetContextFromPaths(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()
	_, err := config.Load(tmpDir, false)
	_, err := config.Init(tmpDir, false)
	if err != nil {
		t.Fatalf("Failed to load config: %v", err)
	}
	cfg := config.Get()
	cfg.WorkingDir = tmpDir
	cfg.ContextPaths = []string{
		"file.txt",
		"directory/",
	}
	testFiles := []string{
		"file.txt",
		"directory/file_a.txt",
@@ -34,7 +28,12 @@ func TestGetContextFromPaths(t *testing.T) {

	createTestFiles(t, tmpDir, testFiles)

	context := getContextFromPaths()
	context := getContextFromPaths(
		[]string{
			"file.txt",
			"directory/",
		},
	)
	expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
	assert.Equal(t, expectedContext, context)
}
@@ -1,8 +1,10 @@
package prompt

import "github.com/charmbracelet/crush/internal/llm/models"
import (
	"github.com/charmbracelet/crush/internal/fur/provider"
)

func SummarizerPrompt(_ models.InferenceProvider) string {
func SummarizerPrompt(_ provider.InferenceProvider) string {
	return `You are a helpful AI assistant tasked with summarizing conversations.

When asked to summarize, provide a detailed but concise summary of the conversation.
@@ -3,10 +3,10 @@ package prompt
import (
	"fmt"

	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/fur/provider"
)

func TaskPrompt(_ models.InferenceProvider) string {
func TaskPrompt(_ provider.InferenceProvider) string {
	agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
@@ -1,8 +1,10 @@
package prompt

import "github.com/charmbracelet/crush/internal/llm/models"
import (
	"github.com/charmbracelet/crush/internal/fur/provider"
)

func TitlePrompt(_ models.InferenceProvider) string {
func TitlePrompt(_ provider.InferenceProvider) string {
	return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message
@@ -13,7 +13,7 @@ import (
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
@@ -59,7 +59,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
	var contentBlocks []anthropic.ContentBlockParamUnion
	contentBlocks = append(contentBlocks, content)
	for _, binaryContent := range msg.BinaryContent() {
		base64Image := binaryContent.String(models.ProviderAnthropic)
		base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
		imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
		contentBlocks = append(contentBlocks, imageBlock)
	}
@@ -164,7 +164,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
	// }

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(a.providerOptions.model.APIModel),
		Model:       anthropic.Model(a.providerOptions.model.ID),
		MaxTokens:   a.providerOptions.maxTokens,
		Temperature: temperature,
		Messages:    messages,
@@ -184,7 +184,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -233,7 +233,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		// jsonData, _ := json.Marshal(preparedMessages)
		// logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -4,7 +4,6 @@ import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"

	"github.com/charmbracelet/crush/internal/llm/tools"
@@ -19,14 +18,8 @@ type bedrockClient struct {
type BedrockClient ProviderClient

func newBedrockClient(opts providerClientOptions) BedrockClient {
	// Apply bedrock specific options if they are added in the future

	// Get AWS region from environment
	region := os.Getenv("AWS_REGION")
	if region == "" {
		region = os.Getenv("AWS_DEFAULT_REGION")
	}

	region := opts.extraParams["region"]
	if region == "" {
		region = "us-east-1" // default region
	}
@@ -39,11 +32,11 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {

	// Prefix the model name with region
	regionPrefix := region[:2]
	modelName := opts.model.APIModel
	opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
	modelName := opts.model.ID
	opts.model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)

	// Determine which provider to use based on the model
	if strings.Contains(string(opts.model.APIModel), "anthropic") {
	if strings.Contains(string(opts.model.ID), "anthropic") {
		// Create Anthropic client with Bedrock configuration
		anthropicOpts := opts
		// TODO: later find a way to check if the AWS account has caching enabled
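The prefixing above keeps only the first two characters of the region, so a model ID picks up a geo prefix. Worked through from the code (the model ID is the Bedrock Claude ID used elsewhere in this change):

// Worked example of the region prefixing above.
region := "eu-west-1"
id := fmt.Sprintf("%s.%s", region[:2], "anthropic.claude-3-7-sonnet-20250219-v1:0")
// id == "eu.anthropic.claude-3-7-sonnet-20250219-v1:0"
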
@@ -157,7 +157,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -173,7 +173,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)

	attempts := 0
	for {
@@ -245,7 +245,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -261,7 +261,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
	if len(tools) > 0 {
		config.Tools = g.convertTools(tools)
	}
	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, config, history)
	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.ID, config, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)
@@ -9,7 +9,7 @@ import (
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
@@ -68,7 +68,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
	textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
	content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
	for _, binaryContent := range msg.BinaryContent() {
		imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
		imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
		imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

		content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -153,7 +153,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {

func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(o.providerOptions.model.APIModel),
		Model:    openai.ChatModel(o.providerOptions.model.ID),
		Messages: messages,
		Tools:    tools,
	}
@@ -180,7 +180,7 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -237,7 +237,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
	}

	cfg := config.Get()
	if cfg.Debug {
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
@@ -3,9 +3,9 @@ package provider
import (
	"context"
	"fmt"
	"os"

	"github.com/charmbracelet/crush/internal/llm/models"
	configv2 "github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
)
@@ -55,17 +55,18 @@ type Provider interface {

	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	Model() models.Model
	Model() configv2.Model
}

type providerClientOptions struct {
	baseURL       string
	apiKey        string
	model         models.Model
	model         configv2.Model
	disableCache  bool
	maxTokens     int64
	systemMessage string
	extraHeaders  map[string]string
	extraParams   map[string]string
}

type ProviderClientOption func(*providerClientOptions)
@@ -80,77 +81,6 @@ type baseProvider[C ProviderClient] struct {
|
||||
client C
|
||||
}
|
||||
|
||||
func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
|
||||
clientOptions := providerClientOptions{}
|
||||
for _, o := range opts {
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch providerName {
|
||||
case models.ProviderAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions, false),
|
||||
}, nil
|
||||
case models.ProviderOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderGemini:
|
||||
return &baseProvider[GeminiClient]{
|
||||
options: clientOptions,
|
||||
client: newGeminiClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderBedrock:
|
||||
return &baseProvider[BedrockClient]{
|
||||
options: clientOptions,
|
||||
client: newBedrockClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderGROQ:
|
||||
clientOptions.baseURL = "https://api.groq.com/openai/v1"
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderAzure:
|
||||
return &baseProvider[AzureClient]{
|
||||
options: clientOptions,
|
||||
client: newAzureClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderVertexAI:
|
||||
return &baseProvider[VertexAIClient]{
|
||||
options: clientOptions,
|
||||
client: newVertexAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderOpenRouter:
|
||||
clientOptions.baseURL = "https://openrouter.ai/api/v1"
|
||||
clientOptions.extraHeaders = map[string]string{
|
||||
"HTTP-Referer": "crush.charm.land",
|
||||
"X-Title": "Crush",
|
||||
}
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderXAI:
|
||||
clientOptions.baseURL = "https://api.x.ai/v1"
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderLocal:
|
||||
clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderMock:
|
||||
// TODO: implement mock client for test
|
||||
panic("not implemented")
|
||||
}
|
||||
return nil, fmt.Errorf("provider not supported: %s", providerName)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
|
||||
for _, msg := range messages {
|
||||
// The message has no content
|
||||
@@ -167,7 +97,7 @@ func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.M
|
||||
return p.client.send(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() models.Model {
|
||||
func (p *baseProvider[C]) Model() configv2.Model {
|
||||
return p.options.model
|
||||
}
|
||||
|
||||
@@ -176,7 +106,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
|
||||
return p.client.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func WithModel(model models.Model) ProviderClientOption {
|
||||
func WithModel(model configv2.Model) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.model = model
|
||||
}
|
||||
@@ -199,3 +129,53 @@ func WithSystemMessage(systemMessage string) ProviderClientOption {
|
||||
options.systemMessage = systemMessage
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
|
||||
clientOptions := providerClientOptions{
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
extraHeaders: cfg.ExtraHeaders,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch cfg.ProviderType {
|
||||
case provider.TypeAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions, false),
|
||||
}, nil
|
||||
case provider.TypeOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeGemini:
|
||||
return &baseProvider[GeminiClient]{
|
||||
options: clientOptions,
|
||||
client: newGeminiClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeBedrock:
|
||||
return &baseProvider[BedrockClient]{
|
||||
options: clientOptions,
|
||||
client: newBedrockClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeAzure:
|
||||
return &baseProvider[AzureClient]{
|
||||
options: clientOptions,
|
||||
client: newAzureClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeVertexAI:
|
||||
return &baseProvider[VertexAIClient]{
|
||||
options: clientOptions,
|
||||
client: newVertexAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeXAI:
|
||||
clientOptions.baseURL = "https://api.x.ai/v1"
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
|
||||
}
|
||||
|
||||
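NewProviderV2 shifts the routing key from a hard-coded provider enum to the provider type carried in the loaded config, so base URLs, API keys, and headers come from configuration rather than a switch over known vendors. A minimal usage sketch, assuming a configv2.ProviderConfig populated by config.Init; the Providers map key and its element type are assumptions, only the fields shown in the hunk above are confirmed:

// Hedged sketch: wiring a provider from the new config.
cfg := configv2.Get()
pCfg := cfg.Providers["openai"] // hypothetical key
p, err := NewProviderV2(pCfg,
    WithModel(configv2.GetAgentModel(configv2.AgentCoder)),
    WithSystemMessage("You are a helpful coding assistant."),
)
if err != nil {
    logging.Error("Failed to construct provider", err)
}
_ = p.Model().ID // the wire identifier now lives on configv2.Model.ID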
@@ -2,7 +2,6 @@ package provider

import (
    "context"
    "os"

    "github.com/charmbracelet/crush/internal/logging"
    "google.golang.org/genai"
@@ -11,9 +10,11 @@ import (
type VertexAIClient ProviderClient

func newVertexAIClient(opts providerClientOptions) VertexAIClient {
    project := opts.extraHeaders["project"]
    location := opts.extraHeaders["location"]
    client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
        Project:  os.Getenv("GOOGLE_CLOUD_PROJECT"),
        Location: os.Getenv("GOOGLE_CLOUD_LOCATION"),
        Project:  project,
        Location: location,
        Backend:  genai.BackendVertexAI,
    })
    if err != nil {

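Vertex AI settings now travel with the provider config instead of the GOOGLE_CLOUD_PROJECT/GOOGLE_CLOUD_LOCATION environment variables, so each configured provider can target its own project. A sketch of the new wiring, with illustrative values (only the "project" and "location" keys are confirmed by the hunk above):

// Hedged sketch: Vertex AI project/location via extra headers.
pCfg := configv2.ProviderConfig{
    ProviderType: provider.TypeVertexAI,
    ExtraHeaders: map[string]string{
        "project":  "my-gcp-project", // hypothetical values
        "location": "us-central1",
    },
}
vertex, err := NewProviderV2(pCfg)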
@@ -286,7 +286,7 @@ func (c *Client) SetServerState(state ServerState) {
// WaitForServerReady waits for the server to be ready by polling the server
// with a simple request until it responds successfully or times out
func (c *Client) WaitForServerReady(ctx context.Context) error {
    cnf := config.Get()
    cfg := config.Get()

    // Set initial state
    c.SetServerState(StateStarting)
@@ -299,7 +299,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
    ticker := time.NewTicker(500 * time.Millisecond)
    defer ticker.Stop()

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Waiting for LSP server to be ready...")
    }

@@ -308,7 +308,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {

    // For TypeScript-like servers, we need to open some key files first
    if serverType == ServerTypeTypeScript {
        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("TypeScript-like server detected, opening key configuration files")
        }
        c.openKeyConfigFiles(ctx)
@@ -325,7 +325,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
        if err == nil {
            // Server responded successfully
            c.SetServerState(StateReady)
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("LSP server is ready")
            }
            return nil
@@ -333,7 +333,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
            logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
        }

        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
        }
    }
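The readiness check described by the doc comment above is a standard poll-with-ticker loop. A minimal self-contained sketch of that shape, assuming context and time are imported; ping is a hypothetical stand-in for the server-specific probe, not the client's actual method:

// Hedged sketch of a poll-until-ready loop.
func waitReady(ctx context.Context, ping func(context.Context) error) error {
    ticker := time.NewTicker(500 * time.Millisecond)
    defer ticker.Stop()
    for {
        select {
        case <-ctx.Done():
            return ctx.Err() // timed out or cancelled
        case <-ticker.C:
            if err := ping(ctx); err == nil {
                return nil // server responded successfully
            }
        }
    }
}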
@@ -496,7 +496,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {

// openTypeScriptFiles finds and opens TypeScript files to help initialize the server
func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
    cnf := config.Get()
    cfg := config.Get()
    filesOpened := 0
    maxFilesToOpen := 5 // Limit to a reasonable number of files

@@ -526,7 +526,7 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
        // Try to open the file
        if err := c.OpenFile(ctx, path); err == nil {
            filesOpened++
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Opened TypeScript file for initialization", "file", path)
            }
        }
@@ -535,11 +535,11 @@ func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
        return nil
    })

    if err != nil && cnf.DebugLSP {
    if err != nil && cfg.Options.DebugLSP {
        logging.Debug("Error walking directory for TypeScript files", "error", err)
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Opened TypeScript files for initialization", "count", filesOpened)
    }
}
@@ -664,7 +664,7 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
}

func (c *Client) CloseFile(ctx context.Context, filepath string) error {
    cnf := config.Get()
    cfg := config.Get()
    uri := string(protocol.URIFromPath(filepath))

    c.openFilesMu.Lock()
@@ -680,7 +680,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
        },
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Closing file", "file", filepath)
    }
    if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
@@ -704,7 +704,7 @@ func (c *Client) IsFileOpen(filepath string) bool {

// CloseAllFiles closes all currently open files
func (c *Client) CloseAllFiles(ctx context.Context) {
    cnf := config.Get()
    cfg := config.Get()
    c.openFilesMu.Lock()
    filesToClose := make([]string, 0, len(c.openFiles))

@@ -719,12 +719,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
    // Then close them all
    for _, filePath := range filesToClose {
        err := c.CloseFile(ctx, filePath)
        if err != nil && cnf.DebugLSP {
        if err != nil && cfg.Options.DebugLSP {
            logging.Warn("Error closing file", "file", filePath, "error", err)
        }
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Closed all files", "files", filesToClose)
    }
}

@@ -82,13 +82,13 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
// Notifications

func HandleServerMessage(params json.RawMessage) {
    cnf := config.Get()
    cfg := config.Get()
    var msg struct {
        Type    int    `json:"type"`
        Message string `json:"message"`
    }
    if err := json.Unmarshal(params, &msg); err == nil {
        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
        }
    }

@@ -18,9 +18,9 @@ func WriteMessage(w io.Writer, msg *Message) error {
    if err != nil {
        return fmt.Errorf("failed to marshal message: %w", err)
    }
    cnf := config.Get()
    cfg := config.Get()

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
    }

@@ -39,7 +39,7 @@ func WriteMessage(w io.Writer, msg *Message) error {

// ReadMessage reads a single LSP message from the given reader
func ReadMessage(r *bufio.Reader) (*Message, error) {
    cnf := config.Get()
    cfg := config.Get()
    // Read headers
    var contentLength int
    for {
@@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
        }
        line = strings.TrimSpace(line)

        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("Received header", "line", line)
        }

@@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
        }
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Content-Length", "length", contentLength)
    }

@@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
        return nil, fmt.Errorf("failed to read content: %w", err)
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Received content", "content", string(content))
    }

@@ -91,11 +91,11 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {

// handleMessages reads and dispatches messages in a loop
func (c *Client) handleMessages() {
    cnf := config.Get()
    cfg := config.Get()
    for {
        msg, err := ReadMessage(c.stdout)
        if err != nil {
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Error("Error reading message", "error", err)
            }
            return
@@ -103,7 +103,7 @@ func (c *Client) handleMessages() {

        // Handle server->client request (has both Method and ID)
        if msg.Method != "" && msg.ID != 0 {
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
            }

@@ -157,11 +157,11 @@ func (c *Client) handleMessages() {
        c.notificationMu.RUnlock()

        if ok {
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Handling notification", "method", msg.Method)
            }
            go handler(msg.Params)
        } else if cnf.DebugLSP {
        } else if cfg.Options.DebugLSP {
            logging.Debug("No handler for notification", "method", msg.Method)
        }
        continue
@@ -174,12 +174,12 @@ func (c *Client) handleMessages() {
        c.handlersMu.RUnlock()

        if ok {
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Received response for request", "id", msg.ID)
            }
            ch <- msg
            close(ch)
        } else if cnf.DebugLSP {
        } else if cfg.Options.DebugLSP {
            logging.Debug("No handler for response", "id", msg.ID)
        }
    }
@@ -188,10 +188,10 @@ func (c *Client) handleMessages() {

// Call makes a request and waits for the response
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
    cnf := config.Get()
    cfg := config.Get()
    id := c.nextID.Add(1)

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Making call", "method", method, "id", id)
    }

@@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
        return fmt.Errorf("failed to send request: %w", err)
    }

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Request sent", "method", method, "id", id)
    }

    // Wait for response
    resp := <-ch

    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Received response", "id", id)
    }

@@ -249,8 +249,8 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any

// Notify sends a notification (a request without an ID that doesn't expect a response)
func (c *Client) Notify(ctx context.Context, method string, params any) error {
    cnf := config.Get()
    if cnf.DebugLSP {
    cfg := config.Get()
    if cfg.Options.DebugLSP {
        logging.Debug("Sending notification", "method", method)
    }


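ReadMessage and WriteMessage implement the LSP base protocol: each JSON-RPC payload is framed by a Content-Length header, a blank line, then exactly that many bytes of JSON. A minimal sketch of writing one frame, assuming an io.Writer w as in WriteMessage; the payload is illustrative:

// Hedged sketch of LSP base-protocol framing.
body, err := json.Marshal(map[string]any{
    "jsonrpc": "2.0",
    "id":      1,
    "method":  "initialize",
    "params":  map[string]any{},
})
if err != nil {
    return err
}
// Header and body are separated by \r\n\r\n per the LSP spec.
if _, err := fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(body)); err != nil {
    return err
}
_, err = w.Write(body)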
@@ -43,7 +43,7 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher {

// AddRegistrations adds file watchers to track
func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) {
    cnf := config.Get()
    cfg := config.Get()

    logging.Debug("Adding file watcher registrations")
    w.registrationMu.Lock()
@@ -53,7 +53,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
    w.registrations = append(w.registrations, watchers...)

    // Print detailed registration information for debugging
    if cnf.DebugLSP {
    if cfg.Options.DebugLSP {
        logging.Debug("Adding file watcher registrations",
            "id", id,
            "watchers", len(watchers),
@@ -122,7 +122,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
            highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName)
            filesOpened += highPriorityFilesOpened

            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Opened high-priority files",
                    "count", highPriorityFilesOpened,
                    "serverName", serverName)
@@ -130,7 +130,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc

            // If we've already opened enough high-priority files, we might not need more
            if filesOpened >= maxFilesToOpen {
                if cnf.DebugLSP {
                if cfg.Options.DebugLSP {
                    logging.Debug("Reached file limit with high-priority files",
                        "filesOpened", filesOpened,
                        "maxFiles", maxFilesToOpen)
@@ -148,7 +148,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
                // Skip directories that should be excluded
                if d.IsDir() {
                    if path != w.workspacePath && shouldExcludeDir(path) {
                        if cnf.DebugLSP {
                        if cfg.Options.DebugLSP {
                            logging.Debug("Skipping excluded directory", "path", path)
                        }
                        return filepath.SkipDir
@@ -176,7 +176,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
            })

            elapsedTime := time.Since(startTime)
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Limited workspace scan complete",
                    "filesOpened", filesOpened,
                    "maxFiles", maxFilesToOpen,
@@ -185,11 +185,11 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
                )
            }

            if err != nil && cnf.DebugLSP {
            if err != nil && cfg.Options.DebugLSP {
                logging.Debug("Error scanning workspace for files to open", "error", err)
            }
        }()
    } else if cnf.DebugLSP {
    } else if cfg.Options.DebugLSP {
        logging.Debug("Using on-demand file loading for server", "server", serverName)
    }
}
@@ -197,7 +197,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
// openHighPriorityFiles opens important files for the server type
// Returns the number of files opened
func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int {
    cnf := config.Get()
    cfg := config.Get()
    filesOpened := 0

    // Define patterns for high-priority files based on server type
@@ -265,7 +265,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
        // Use doublestar.Glob to find files matching the pattern (supports ** patterns)
        matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern)
        if err != nil {
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err)
            }
            continue
@@ -299,12 +299,12 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName
        for j := i; j < end; j++ {
            fullPath := filesToOpen[j]
            if err := w.client.OpenFile(ctx, fullPath); err != nil {
                if cnf.DebugLSP {
                if cfg.Options.DebugLSP {
                    logging.Debug("Error opening high-priority file", "path", fullPath, "error", err)
                }
            } else {
                filesOpened++
                if cnf.DebugLSP {
                if cfg.Options.DebugLSP {
                    logging.Debug("Opened high-priority file", "path", fullPath)
                }
            }
@@ -321,7 +321,7 @@ func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName

// WatchWorkspace sets up file watching for a workspace
func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath string) {
    cnf := config.Get()
    cfg := config.Get()
    w.workspacePath = workspacePath

    // Store the watcher in the context for later use
@@ -356,7 +356,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
        // Skip excluded directories (except workspace root)
        if d.IsDir() && path != workspacePath {
            if shouldExcludeDir(path) {
                if cnf.DebugLSP {
                if cfg.Options.DebugLSP {
                    logging.Debug("Skipping excluded directory", "path", path)
                }
                return filepath.SkipDir
@@ -409,7 +409,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
            }

            // Debug logging
            if cnf.DebugLSP {
            if cfg.Options.DebugLSP {
                matched, kind := w.isPathWatched(event.Name)
                logging.Debug("File event",
                    "path", event.Name,
@@ -676,8 +676,8 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan

// notifyFileEvent sends a didChangeWatchedFiles notification for a file event
func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
    cnf := config.Get()
    if cnf.DebugLSP {
    cfg := config.Get()
    if cfg.Options.DebugLSP {
        logging.Debug("Notifying file event",
            "uri", uri,
            "changeType", changeType,
@@ -826,7 +826,7 @@ func shouldExcludeDir(dirPath string) bool {
// shouldExcludeFile returns true if the file should be excluded from opening
func shouldExcludeFile(filePath string) bool {
    fileName := filepath.Base(filePath)
    cnf := config.Get()
    cfg := config.Get()
    // Skip dot files
    if strings.HasPrefix(fileName, ".") {
        return true
@@ -852,12 +852,12 @@ func shouldExcludeFile(filePath string) bool {

    // Skip large files
    if info.Size() > maxFileSize {
        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("Skipping large file",
                "path", filePath,
                "size", info.Size(),
                "maxSize", maxFileSize,
                "debug", cnf.Debug,
                "debug", cfg.Options.Debug,
                "sizeMB", float64(info.Size())/(1024*1024),
                "maxSizeMB", float64(maxFileSize)/(1024*1024),
            )
@@ -870,7 +870,7 @@ func shouldExcludeFile(filePath string) bool {

// openMatchingFile opens a file if it matches any of the registered patterns
func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
    cnf := config.Get()
    cfg := config.Get()
    // Skip directories
    info, err := os.Stat(path)
    if err != nil || info.IsDir() {
@@ -890,10 +890,10 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
    // Check if the file is a high-priority file that should be opened immediately
    // This helps with project initialization for certain language servers
    if isHighPriorityFile(path, serverName) {
        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("Opening high-priority file", "path", path, "serverName", serverName)
        }
        if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
        if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
            logging.Error("Error opening high-priority file", "path", path, "error", err)
        }
        return
@@ -905,7 +905,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {

    // Check file size - for preloading we're more conservative
    if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files
        if cnf.DebugLSP {
        if cfg.Options.DebugLSP {
            logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size())
        }
        return
@@ -937,7 +937,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {

    if shouldOpen {
        // Don't need to check if it's already open - the client.OpenFile handles that
        if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP {
        if err := w.client.OpenFile(ctx, path); err != nil && cfg.Options.DebugLSP {
            logging.Error("Error opening file", "path", path, "error", err)
        }
    }

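openHighPriorityFiles relies on doublestar's **-aware globbing over an fs.FS rooted at the workspace, which plain filepath.Glob cannot do because its * does not cross path separators. A small standalone sketch of that call; the pattern and root are illustrative:

// Hedged sketch: **-aware globbing as used above.
package main

import (
    "fmt"
    "os"

    "github.com/bmatcuk/doublestar/v4"
)

func main() {
    // "**/tsconfig.json" matches at any directory depth under the root.
    matches, err := doublestar.Glob(os.DirFS("."), "**/tsconfig.json")
    if err != nil {
        fmt.Println("glob error:", err)
        return
    }
    fmt.Println(matches)
}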
@@ -5,7 +5,7 @@ import (
    "slices"
    "time"

    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/fur/provider"
)

type MessageRole string
@@ -71,9 +71,9 @@ type BinaryContent struct {
    Data []byte
}

func (bc BinaryContent) String(provider models.InferenceProvider) string {
func (bc BinaryContent) String(p provider.InferenceProvider) string {
    base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
    if provider == models.ProviderOpenAI {
    if p == provider.InferenceProviderOpenAI {
        return "data:" + bc.MIMEType + ";base64," + base64Encoded
    }
    return base64Encoded
@@ -113,7 +113,8 @@ type Message struct {
    Role      MessageRole
    SessionID string
    Parts     []ContentPart
    Model     models.ModelID
    Model     string
    Provider  string
    CreatedAt int64
    UpdatedAt int64
}

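The parameter rename from provider to p avoids shadowing the newly imported provider package; the behavior is unchanged: OpenAI gets a data URL, everyone else gets raw base64. A quick sketch, where pngBytes is hypothetical and InferenceProviderGemini is assumed to exist alongside InferenceProviderOpenAI:

// Hedged sketch of BinaryContent.String behavior.
bc := message.BinaryContent{MIMEType: "image/png", Data: pngBytes}
_ = bc.String(provider.InferenceProviderOpenAI) // "data:image/png;base64,...."
_ = bc.String(provider.InferenceProviderGemini) // raw base64, no data-URL wrapper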
@@ -8,15 +8,15 @@ import (
    "time"

    "github.com/charmbracelet/crush/internal/db"
    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/pubsub"
    "github.com/google/uuid"
)

type CreateMessageParams struct {
    Role  MessageRole
    Parts []ContentPart
    Model models.ModelID
    Role     MessageRole
    Parts    []ContentPart
    Model    string
    Provider string
}

type Service interface {
@@ -70,6 +70,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
        Role:     string(params.Role),
        Parts:    string(partsJSON),
        Model:    sql.NullString{String: string(params.Model), Valid: true},
        Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
    })
    if err != nil {
        return Message{}, err
@@ -154,7 +155,8 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
        SessionID: item.SessionID,
        Role:      MessageRole(item.Role),
        Parts:     parts,
        Model:     models.ModelID(item.Model.String),
        Model:     item.Model.String,
        Provider:  item.Provider.String,
        CreatedAt: item.CreatedAt,
        UpdatedAt: item.UpdatedAt,
    }, nil

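Messages now persist the provider alongside the model ID, since a bare model string is no longer globally unique once providers supply their own catalogs. A minimal sketch of creating a message under the new params; the role constant name and the model/provider values are assumptions:

// Hedged sketch: Provider disambiguates identical model IDs across providers.
msg, err := messages.Create(ctx, sessionID, message.CreateMessageParams{
    Role:     message.Assistant, // assumed role constant
    Parts:    parts,
    Model:    "gpt-4o",   // illustrative
    Provider: "openai",   // illustrative
})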
@@ -7,7 +7,6 @@ import (
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/fsext"
    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/lsp"
    "github.com/charmbracelet/crush/internal/lsp/protocol"
    "github.com/charmbracelet/crush/internal/pubsub"
@@ -112,11 +111,7 @@ func (h *header) details() string {
        parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
    }

    cfg := config.Get()
    agentCfg := cfg.Agents[config.AgentCoder]
    selectedModelID := agentCfg.Model
    model := models.SupportedModels[selectedModelID]

    model := config.GetAgentModel(config.AgentCoder)
    percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
    formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
    parts = append(parts, formattedPercentage)

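The header's context gauge is plain arithmetic over the session's token counts. A worked example with illustrative numbers: a 200,000-token window with 150,000 prompt+completion tokens reads 75%.

// Worked example of the percentage computation above.
contextWindow := int64(200_000)
used := int64(150_000) // CompletionTokens + PromptTokens
percentage := (float64(used) / float64(contextWindow)) * 100
fmt.Printf("%d%%\n", int(percentage)) // 75%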
@@ -10,7 +10,8 @@ import (
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/lipgloss/v2"

    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/fur/provider"
    "github.com/charmbracelet/crush/internal/message"
    "github.com/charmbracelet/crush/internal/tui/components/anim"
    "github.com/charmbracelet/crush/internal/tui/components/core"
@@ -290,8 +291,9 @@ func (m *assistantSectionModel) View() tea.View {
    duration := finishTime.Sub(m.lastUserMessageTime)
    infoMsg := t.S().Subtle.Render(duration.String())
    icon := t.S().Subtle.Render(styles.ModelIcon)
    model := t.S().Muted.Render(models.SupportedModels[m.message.Model].Name)
    assistant := fmt.Sprintf("%s %s %s", icon, model, infoMsg)
    model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
    modelFormatted := t.S().Muted.Render(model.Name)
    assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
    return tea.NewView(
        t.S().Base.PaddingLeft(2).Render(
            core.Section(assistant, m.width-2),

@@ -13,7 +13,6 @@ import (
    "github.com/charmbracelet/crush/internal/diff"
    "github.com/charmbracelet/crush/internal/fsext"
    "github.com/charmbracelet/crush/internal/history"
    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/logging"
    "github.com/charmbracelet/crush/internal/lsp"
    "github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -406,7 +405,7 @@ func (m *sidebarCmp) mcpBlock() string {

    mcpList := []string{section, ""}

    mcp := config.Get().MCPServers
    mcp := config.Get().MCP
    if len(mcp) == 0 {
        return lipgloss.JoinVertical(
            lipgloss.Left,
@@ -475,10 +474,7 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}

func (s *sidebarCmp) currentModelBlock() string {
    cfg := config.Get()
    agentCfg := cfg.Agents[config.AgentCoder]
    selectedModelID := agentCfg.Model
    model := models.SupportedModels[selectedModelID]
    model := config.GetAgentModel(config.AgentCoder)

    t := styles.CurrentTheme()


@@ -63,7 +63,7 @@ func buildCommandSources(cfg *config.Config) []commandSource {

    // Project directory
    sources = append(sources, commandSource{
        path:   filepath.Join(cfg.Data.Directory, "commands"),
        path:   filepath.Join(cfg.Options.DataDirectory, "commands"),
        prefix: ProjectCommandPrefix,
    })


@@ -5,7 +5,7 @@ import (
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/lipgloss/v2"

    "github.com/charmbracelet/crush/internal/config"
    configv2 "github.com/charmbracelet/crush/internal/config"
    cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat"
    "github.com/charmbracelet/crush/internal/tui/components/core"
    "github.com/charmbracelet/crush/internal/tui/components/dialogs"
@@ -184,7 +184,7 @@ If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (
Add the .crush directory to the .gitignore file if it's not already there.`

        // Mark the project as initialized
        if err := config.MarkProjectInitialized(); err != nil {
        if err := configv2.MarkProjectInitialized(); err != nil {
            return util.ReportError(err)
        }

@@ -196,7 +196,7 @@ Add the .crush directory to the .gitignore file if it's not already there.`
        )
    } else {
        // Mark the project as initialized without running the command
        if err := config.MarkProjectInitialized(); err != nil {
        if err := configv2.MarkProjectInitialized(); err != nil {
            return util.ReportError(err)
        }
    }

@@ -1,13 +1,11 @@
package models

import (
    "slices"

    "github.com/charmbracelet/bubbles/v2/help"
    "github.com/charmbracelet/bubbles/v2/key"
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/llm/models"
    configv2 "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/fur/provider"
    "github.com/charmbracelet/crush/internal/tui/components/completions"
    "github.com/charmbracelet/crush/internal/tui/components/core"
    "github.com/charmbracelet/crush/internal/tui/components/core/list"
@@ -26,7 +24,7 @@ const (

// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
    Model models.Model
    Model configv2.PreferredModel
}

// CloseModelDialogMsg is sent when a model is selected
@@ -37,6 +35,11 @@ type ModelDialog interface {
    dialogs.DialogModel
}

type ModelOption struct {
    Provider provider.Provider
    Model    provider.Model
}

type modelDialogCmp struct {
    width  int
    wWidth int // Width of the terminal window
@@ -80,47 +83,31 @@ func NewModelDialogCmp() ModelDialog {
    }
}

var ProviderPopularity = map[models.InferenceProvider]int{
    models.ProviderAnthropic:  1,
    models.ProviderOpenAI:     2,
    models.ProviderGemini:     3,
    models.ProviderGROQ:       4,
    models.ProviderOpenRouter: 5,
    models.ProviderBedrock:    6,
    models.ProviderAzure:      7,
    models.ProviderVertexAI:   8,
    models.ProviderXAI:        9,
}

var ProviderName = map[models.InferenceProvider]string{
    models.ProviderAnthropic:  "Anthropic",
    models.ProviderOpenAI:     "OpenAI",
    models.ProviderGemini:     "Gemini",
    models.ProviderGROQ:       "Groq",
    models.ProviderOpenRouter: "OpenRouter",
    models.ProviderBedrock:    "AWS Bedrock",
    models.ProviderAzure:      "Azure",
    models.ProviderVertexAI:   "VertexAI",
    models.ProviderXAI:        "xAI",
}

func (m *modelDialogCmp) Init() tea.Cmd {
    cfg := config.Get()
    enabledProviders := getEnabledProviders(cfg)
    providers := configv2.Providers()
    cfg := configv2.Get()

    coderAgent := cfg.Agents[configv2.AgentCoder]
    modelItems := []util.Model{}
    for _, provider := range enabledProviders {
        name, ok := ProviderName[provider]
        if !ok {
            name = string(provider) // Fallback to provider ID if name is not defined
    selectIndex := 0
    for _, provider := range providers {
        name := provider.Name
        if name == "" {
            name = string(provider.ID)
        }
        modelItems = append(modelItems, commands.NewItemSection(name))
        for _, model := range getModelsForProvider(provider) {
            modelItems = append(modelItems, completions.NewCompletionItem(model.Name, model))
        for _, model := range provider.Models {
            if model.ID == coderAgent.Model && provider.ID == coderAgent.Provider {
                selectIndex = len(modelItems) // Set the selected index to the current model
            }
            modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
                Provider: provider,
                Model:    model,
            }))
        }
    }
    m.modelList.SetItems(modelItems)
    return m.modelList.Init()

    return tea.Sequence(m.modelList.Init(), m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
}

func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -137,11 +124,14 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
            return m, nil // No item selected, do nothing
        }
        items := m.modelList.Items()
        selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(models.Model)
        selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)

        return m, tea.Sequence(
            util.CmdHandler(dialogs.CloseDialogMsg{}),
            util.CmdHandler(ModelSelectedMsg{Model: selectedItem}),
            util.CmdHandler(ModelSelectedMsg{Model: configv2.PreferredModel{
                ModelID:  selectedItem.Model.ID,
                Provider: selectedItem.Provider.ID,
            }}),
        )
    case key.Matches(msg, m.keyMap.Close):
        return m, util.CmdHandler(dialogs.CloseDialogMsg{})
@@ -189,58 +179,6 @@ func (m *modelDialogCmp) listHeight() int {
    return min(listHeigh, m.wHeight/2)
}

func GetSelectedModel(cfg *config.Config) models.Model {
    agentCfg := cfg.Agents[config.AgentCoder]
    selectedModelID := agentCfg.Model
    return models.SupportedModels[selectedModelID]
}

func getEnabledProviders(cfg *config.Config) []models.InferenceProvider {
    var providers []models.InferenceProvider
    for providerID, provider := range cfg.Providers {
        if !provider.Disabled {
            providers = append(providers, providerID)
        }
    }

    // Sort by provider popularity
    slices.SortFunc(providers, func(a, b models.InferenceProvider) int {
        rA := ProviderPopularity[a]
        rB := ProviderPopularity[b]

        // models not included in popularity ranking default to last
        if rA == 0 {
            rA = 999
        }
        if rB == 0 {
            rB = 999
        }
        return rA - rB
    })
    return providers
}

func getModelsForProvider(provider models.InferenceProvider) []models.Model {
    var providerModels []models.Model
    for _, model := range models.SupportedModels {
        if model.Provider == provider {
            providerModels = append(providerModels, model)
        }
    }

    // reverse alphabetical order (if llm naming was consistent latest would appear first)
    slices.SortFunc(providerModels, func(a, b models.Model) int {
        if a.Name > b.Name {
            return -1
        } else if a.Name < b.Name {
            return 1
        }
        return 0
    })

    return providerModels
}

func (m *modelDialogCmp) Position() (int, int) {
    row := m.wHeight/4 - 2 // just a bit above the center
    col := m.wWidth / 2

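The dialog now stores a ModelOption in each completion item and recovers it with a type assertion on selection, pairing the model with its provider before emitting a PreferredModel. A compact sketch of that round-trip, outside the list machinery and with hypothetical values:

// Hedged sketch of the selection round-trip shown above.
item := completions.NewCompletionItem("GPT-4o", ModelOption{ // label is illustrative
    Provider: someProvider, // a provider.Provider from configv2.Providers()
    Model:    someModel,    // one of someProvider.Models
})
opt := item.Value().(ModelOption) // recover the typed value on selection
selected := configv2.PreferredModel{ModelID: opt.Model.ID, Provider: opt.Provider.ID}
_ = selected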
@@ -9,7 +9,6 @@ import (
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/crush/internal/app"
    "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/llm/models"
    "github.com/charmbracelet/crush/internal/message"
    "github.com/charmbracelet/crush/internal/session"
    "github.com/charmbracelet/crush/internal/tui/components/chat"
@@ -171,14 +170,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
            util.CmdHandler(ChatFocusedMsg{Focused: false}),
        )
    case key.Matches(msg, p.keyMap.AddAttachment):
        cfg := config.Get()
        agentCfg := cfg.Agents[config.AgentCoder]
        selectedModelID := agentCfg.Model
        model := models.SupportedModels[selectedModelID]
        if model.SupportsAttachments {
        model := config.GetAgentModel(config.AgentCoder)
        if model.SupportsImages {
            return p, util.CmdHandler(OpenFilePickerMsg{})
        } else {
            return p, util.ReportWarn("File attachments are not supported by the current model: " + string(selectedModelID))
            return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
        }
    case key.Matches(msg, p.keyMap.Tab):
        if p.session.ID == "" {

@@ -8,6 +8,7 @@ import (
    tea "github.com/charmbracelet/bubbletea/v2"
    "github.com/charmbracelet/crush/internal/app"
    "github.com/charmbracelet/crush/internal/config"
    configv2 "github.com/charmbracelet/crush/internal/config"
    "github.com/charmbracelet/crush/internal/llm/agent"
    "github.com/charmbracelet/crush/internal/logging"
    "github.com/charmbracelet/crush/internal/permission"
@@ -69,7 +70,7 @@ func (a appModel) Init() tea.Cmd {

    // Check if we should show the init dialog
    cmds = append(cmds, func() tea.Msg {
        shouldShow, err := config.ShouldShowInitDialog()
        shouldShow, err := configv2.ProjectNeedsInitialization()
        if err != nil {
            return util.InfoMsg{
                Type: util.InfoTypeError,
@@ -172,7 +173,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {

    // Model Switch
    case models.ModelSelectedMsg:
        model, err := a.app.CoderAgent.Update(config.AgentCoder, msg.Model.ID)
        model, err := a.app.CoderAgent.Update(msg.Model)
        if err != nil {
            return a, util.ReportError(err)
        }
@@ -222,7 +223,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
        model := a.app.CoderAgent.Model()
        contextWindow := model.ContextWindow
        tokens := session.CompletionTokens + session.PromptTokens
        if (tokens >= int64(float64(contextWindow)*0.95)) && config.Get().AutoCompact {
        if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize {
            // Show compact confirmation dialog
            cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{
                Model: compact.NewCompactDialogCmp(a.app.CoderAgent, a.selectedSessionID, false),

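Note the flag also flips polarity here: AutoCompact (opt-in) becomes DisableAutoSummarize (opt-out), so summarization is on by default. A worked example of the 95% threshold with an illustrative window size: at a 200,000-token context window the dialog triggers once the session reaches 190,000 tokens.

// Worked example of the threshold check above.
contextWindow := int64(200_000)
tokens := int64(190_000)
shouldPrompt := tokens >= int64(float64(contextWindow)*0.95) // true at exactly 190,000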