wip: integrate to existing app

This commit is contained in:
Kujtim Hoxha
2025-07-05 16:16:55 +02:00
parent d51853c8a9
commit 7f078a6e20
58 changed files with 525 additions and 5907 deletions

View File

@@ -7,7 +7,7 @@ import (
"slices"
"time"
"github.com/charmbracelet/crush/pkg/config"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/log/v2"
"github.com/nxadm/tail"
"github.com/spf13/cobra"

View File

@@ -1,155 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"os"
"github.com/charmbracelet/crush/internal/config"
"github.com/invopop/jsonschema"
)
// main generates the JSON Schema for the Crush configuration and writes it,
// indented, to stdout. Exits with a non-zero status if encoding fails.
func main() {
	reflector := &jsonschema.Reflector{
		// Anonymous schemas avoid $id collisions between nested types.
		Anonymous: true,
		// Inline the root struct rather than emitting a $ref to it.
		ExpandedStruct:            true,
		AllowAdditionalProperties: true,
	}

	// Reflect the top-level Config struct into a schema skeleton.
	schema := reflector.Reflect(&config.Config{})
	schema.Version = "https://json-schema.org/draft/2020-12/schema"
	schema.Title = "Crush Configuration"
	schema.Description = "Configuration schema for the Crush application"

	// Enrich the reflected schema with enums and defaults.
	enhanceSchema(schema)

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")
	if err := enc.Encode(schema); err != nil {
		fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err)
		os.Exit(1)
	}
}
// enhanceSchema applies all post-reflection enrichment passes to the
// generated schema: provider/model/tool enums and default context paths.
func enhanceSchema(schema *jsonschema.Schema) {
	passes := []func(*jsonschema.Schema){
		addProviderEnums,
		addModelEnums,
		addToolEnums,
		addDefaultContextPaths,
	}
	for _, pass := range passes {
		pass(schema)
	}
}
// addProviderEnums constrains the PreferredModel "provider" field and the
// ProviderConfig "id" field to the set of known provider IDs reported by
// config.Providers().
func addProviderEnums(schema *jsonschema.Schema) {
	var providerIDs []any
	for _, p := range config.Providers() {
		providerIDs = append(providerIDs, string(p.ID))
	}
	if schema.Definitions == nil {
		return
	}
	if def, ok := schema.Definitions["PreferredModel"]; ok {
		if prop, ok := def.Properties.Get("provider"); ok {
			prop.Enum = providerIDs
		}
	}
	if def, ok := schema.Definitions["ProviderConfig"]; ok {
		if prop, ok := def.Properties.Get("id"); ok {
			prop.Enum = providerIDs
		}
	}
}
// addModelEnums constrains the PreferredModel "model_id" field to the union
// of model IDs across every known provider.
func addModelEnums(schema *jsonschema.Schema) {
	var modelIDs []any
	for _, p := range config.Providers() {
		for _, m := range p.Models {
			modelIDs = append(modelIDs, m.ID)
		}
	}
	if schema.Definitions == nil {
		return
	}
	def, ok := schema.Definitions["PreferredModel"]
	if !ok {
		return
	}
	if prop, ok := def.Properties.Get("model_id"); ok {
		prop.Enum = modelIDs
	}
}
// addToolEnums constrains the items of the Agent "allowed_tools" array to
// the fixed list of built-in tool names.
func addToolEnums(schema *jsonschema.Schema) {
	if schema.Definitions == nil {
		return
	}
	agentDef, ok := schema.Definitions["Agent"]
	if !ok {
		return
	}
	prop, ok := agentDef.Properties.Get("allowed_tools")
	if !ok || prop.Items == nil {
		return
	}
	prop.Items.Enum = []any{
		"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
	}
}
// addDefaultContextPaths injects the default list of context files into the
// "context_paths" property, both on the Options definition and (when the
// root inlines it) on the top-level "options" property.
func addDefaultContextPaths(schema *jsonschema.Schema) {
	defaults := []any{
		".github/copilot-instructions.md",
		".cursorrules",
		".cursor/rules/",
		"CLAUDE.md",
		"CLAUDE.local.md",
		"GEMINI.md",
		"gemini.md",
		"crush.md",
		"crush.local.md",
		"Crush.md",
		"Crush.local.md",
		"CRUSH.md",
		"CRUSH.local.md",
	}
	if schema.Definitions != nil {
		if def, ok := schema.Definitions["Options"]; ok {
			if prop, ok := def.Properties.Get("context_paths"); ok {
				prop.Default = defaults
			}
		}
	}
	if schema.Properties == nil {
		return
	}
	opts, ok := schema.Properties.Get("options")
	if !ok || opts.Properties == nil {
		return
	}
	if prop, ok := opts.Properties.Get("context_paths"); ok {
		prop.Default = defaults
	}
}

View File

@@ -1,700 +0,0 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$defs": {
"Agent": {
"properties": {
"id": {
"type": "string",
"enum": [
"coder",
"task"
],
"title": "Agent ID",
"description": "Unique identifier for the agent"
},
"name": {
"type": "string",
"title": "Name",
"description": "Display name of the agent"
},
"description": {
"type": "string",
"title": "Description",
"description": "Description of what the agent does"
},
"disabled": {
"type": "boolean",
"title": "Disabled",
"description": "Whether this agent is disabled",
"default": false
},
"model": {
"type": "string",
"enum": [
"large",
"small"
],
"title": "Model Type",
"description": "Type of model to use (large or small)"
},
"allowed_tools": {
"items": {
"type": "string",
"enum": [
"bash",
"edit",
"fetch",
"glob",
"grep",
"ls",
"sourcegraph",
"view",
"write",
"agent"
]
},
"type": "array",
"title": "Allowed Tools",
"description": "List of tools this agent is allowed to use (if nil all tools are allowed)"
},
"allowed_mcp": {
"additionalProperties": {
"items": {
"type": "string"
},
"type": "array"
},
"type": "object",
"title": "Allowed MCP",
"description": "Map of MCP servers this agent can use and their allowed tools"
},
"allowed_lsp": {
"items": {
"type": "string"
},
"type": "array",
"title": "Allowed LSP",
"description": "List of LSP servers this agent can use (if nil all LSPs are allowed)"
},
"context_paths": {
"items": {
"type": "string"
},
"type": "array",
"title": "Context Paths",
"description": "Custom context paths for this agent (additive to global context paths)"
}
},
"type": "object",
"required": [
"model"
]
},
"LSPConfig": {
"properties": {
"enabled": {
"type": "boolean",
"title": "Enabled",
"description": "Whether this LSP server is enabled",
"default": true
},
"command": {
"type": "string",
"title": "Command",
"description": "Command to execute for the LSP server"
},
"args": {
"items": {
"type": "string"
},
"type": "array",
"title": "Arguments",
"description": "Command line arguments for the LSP server"
},
"options": {
"title": "Options",
"description": "LSP server specific options"
}
},
"type": "object",
"required": [
"command"
]
},
"MCP": {
"properties": {
"command": {
"type": "string",
"title": "Command",
"description": "Command to execute for stdio MCP servers"
},
"env": {
"items": {
"type": "string"
},
"type": "array",
"title": "Environment",
"description": "Environment variables for the MCP server"
},
"args": {
"items": {
"type": "string"
},
"type": "array",
"title": "Arguments",
"description": "Command line arguments for the MCP server"
},
"type": {
"type": "string",
"enum": [
"stdio",
"sse",
"http"
],
"title": "Type",
"description": "Type of MCP connection",
"default": "stdio"
},
"url": {
"type": "string",
"title": "URL",
"description": "URL for SSE MCP servers"
},
"headers": {
"additionalProperties": {
"type": "string"
},
"type": "object",
"title": "Headers",
"description": "HTTP headers for SSE MCP servers"
}
},
"type": "object",
"required": [
"type"
]
},
"Model": {
"properties": {
"id": {
"type": "string",
"title": "Model ID",
"description": "Unique identifier for the model"
},
"name": {
"type": "string",
"title": "Model Name",
"description": "Display name of the model"
},
"cost_per_1m_in": {
"type": "number",
"minimum": 0,
"title": "Input Cost",
"description": "Cost per 1 million input tokens"
},
"cost_per_1m_out": {
"type": "number",
"minimum": 0,
"title": "Output Cost",
"description": "Cost per 1 million output tokens"
},
"cost_per_1m_in_cached": {
"type": "number",
"minimum": 0,
"title": "Cached Input Cost",
"description": "Cost per 1 million cached input tokens"
},
"cost_per_1m_out_cached": {
"type": "number",
"minimum": 0,
"title": "Cached Output Cost",
"description": "Cost per 1 million cached output tokens"
},
"context_window": {
"type": "integer",
"minimum": 1,
"title": "Context Window",
"description": "Maximum context window size in tokens"
},
"default_max_tokens": {
"type": "integer",
"minimum": 1,
"title": "Default Max Tokens",
"description": "Default maximum tokens for responses"
},
"can_reason": {
"type": "boolean",
"title": "Can Reason",
"description": "Whether the model supports reasoning capabilities"
},
"reasoning_effort": {
"type": "string",
"title": "Reasoning Effort",
"description": "Default reasoning effort level for reasoning models"
},
"has_reasoning_effort": {
"type": "boolean",
"title": "Has Reasoning Effort",
"description": "Whether the model supports reasoning effort configuration"
},
"supports_attachments": {
"type": "boolean",
"title": "Supports Images",
"description": "Whether the model supports image attachments"
}
},
"type": "object",
"required": [
"id",
"name",
"context_window",
"default_max_tokens"
]
},
"Options": {
"properties": {
"context_paths": {
"items": {
"type": "string"
},
"type": "array",
"title": "Context Paths",
"description": "List of paths to search for context files",
"default": [
".github/copilot-instructions.md",
".cursorrules",
".cursor/rules/",
"CLAUDE.md",
"CLAUDE.local.md",
"GEMINI.md",
"gemini.md",
"crush.md",
"crush.local.md",
"Crush.md",
"Crush.local.md",
"CRUSH.md",
"CRUSH.local.md"
]
},
"tui": {
"$ref": "#/$defs/TUIOptions",
"title": "TUI Options",
"description": "Terminal UI configuration options"
},
"debug": {
"type": "boolean",
"title": "Debug",
"description": "Enable debug logging",
"default": false
},
"debug_lsp": {
"type": "boolean",
"title": "Debug LSP",
"description": "Enable LSP debug logging",
"default": false
},
"disable_auto_summarize": {
"type": "boolean",
"title": "Disable Auto Summarize",
"description": "Disable automatic conversation summarization",
"default": false
},
"data_directory": {
"type": "string",
"title": "Data Directory",
"description": "Directory for storing application data",
"default": ".crush"
}
},
"type": "object"
},
"PreferredModel": {
"properties": {
"model_id": {
"type": "string",
"enum": [
"claude-opus-4-20250514",
"claude-sonnet-4-20250514",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"codex-mini-latest",
"o4-mini",
"o3",
"o3-pro",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4.5-preview",
"o3-mini",
"gpt-4o",
"gpt-4o-mini",
"gemini-2.5-pro",
"gemini-2.5-flash",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
"anthropic.claude-3-7-sonnet-20250219-v1:0",
"anthropic.claude-3-5-haiku-20241022-v1:0",
"grok-3-mini",
"grok-3",
"mistralai/mistral-small-3.2-24b-instruct:free",
"mistralai/mistral-small-3.2-24b-instruct",
"minimax/minimax-m1:extended",
"minimax/minimax-m1",
"google/gemini-2.5-flash-lite-preview-06-17",
"google/gemini-2.5-flash",
"google/gemini-2.5-pro",
"openai/o3-pro",
"x-ai/grok-3-mini",
"x-ai/grok-3",
"mistralai/magistral-small-2506",
"mistralai/magistral-medium-2506",
"mistralai/magistral-medium-2506:thinking",
"google/gemini-2.5-pro-preview",
"deepseek/deepseek-r1-0528",
"anthropic/claude-opus-4",
"anthropic/claude-sonnet-4",
"mistralai/devstral-small:free",
"mistralai/devstral-small",
"google/gemini-2.5-flash-preview-05-20",
"google/gemini-2.5-flash-preview-05-20:thinking",
"openai/codex-mini",
"mistralai/mistral-medium-3",
"google/gemini-2.5-pro-preview-05-06",
"arcee-ai/caller-large",
"arcee-ai/virtuoso-large",
"arcee-ai/virtuoso-medium-v2",
"qwen/qwen3-30b-a3b",
"qwen/qwen3-14b",
"qwen/qwen3-32b",
"qwen/qwen3-235b-a22b",
"google/gemini-2.5-flash-preview",
"google/gemini-2.5-flash-preview:thinking",
"openai/o4-mini-high",
"openai/o3",
"openai/o4-mini",
"openai/gpt-4.1",
"openai/gpt-4.1-mini",
"openai/gpt-4.1-nano",
"x-ai/grok-3-mini-beta",
"x-ai/grok-3-beta",
"meta-llama/llama-4-maverick",
"meta-llama/llama-4-scout",
"all-hands/openhands-lm-32b-v0.1",
"google/gemini-2.5-pro-exp-03-25",
"deepseek/deepseek-chat-v3-0324:free",
"deepseek/deepseek-chat-v3-0324",
"mistralai/mistral-small-3.1-24b-instruct:free",
"mistralai/mistral-small-3.1-24b-instruct",
"ai21/jamba-1.6-large",
"ai21/jamba-1.6-mini",
"openai/gpt-4.5-preview",
"google/gemini-2.0-flash-lite-001",
"anthropic/claude-3.7-sonnet",
"anthropic/claude-3.7-sonnet:beta",
"anthropic/claude-3.7-sonnet:thinking",
"mistralai/mistral-saba",
"openai/o3-mini-high",
"google/gemini-2.0-flash-001",
"qwen/qwen-turbo",
"qwen/qwen-plus",
"qwen/qwen-max",
"openai/o3-mini",
"mistralai/mistral-small-24b-instruct-2501",
"deepseek/deepseek-r1-distill-llama-70b",
"deepseek/deepseek-r1",
"mistralai/codestral-2501",
"deepseek/deepseek-chat",
"openai/o1",
"x-ai/grok-2-1212",
"meta-llama/llama-3.3-70b-instruct",
"amazon/nova-lite-v1",
"amazon/nova-micro-v1",
"amazon/nova-pro-v1",
"openai/gpt-4o-2024-11-20",
"mistralai/mistral-large-2411",
"mistralai/mistral-large-2407",
"mistralai/pixtral-large-2411",
"thedrummer/unslopnemo-12b",
"anthropic/claude-3.5-haiku:beta",
"anthropic/claude-3.5-haiku",
"anthropic/claude-3.5-haiku-20241022:beta",
"anthropic/claude-3.5-haiku-20241022",
"anthropic/claude-3.5-sonnet:beta",
"anthropic/claude-3.5-sonnet",
"x-ai/grok-beta",
"mistralai/ministral-8b",
"mistralai/ministral-3b",
"nvidia/llama-3.1-nemotron-70b-instruct",
"google/gemini-flash-1.5-8b",
"meta-llama/llama-3.2-11b-vision-instruct",
"meta-llama/llama-3.2-3b-instruct",
"qwen/qwen-2.5-72b-instruct",
"mistralai/pixtral-12b",
"cohere/command-r-plus-08-2024",
"cohere/command-r-08-2024",
"microsoft/phi-3.5-mini-128k-instruct",
"nousresearch/hermes-3-llama-3.1-70b",
"openai/gpt-4o-2024-08-06",
"meta-llama/llama-3.1-405b-instruct",
"meta-llama/llama-3.1-70b-instruct",
"meta-llama/llama-3.1-8b-instruct",
"mistralai/mistral-nemo",
"openai/gpt-4o-mini",
"openai/gpt-4o-mini-2024-07-18",
"anthropic/claude-3.5-sonnet-20240620:beta",
"anthropic/claude-3.5-sonnet-20240620",
"mistralai/mistral-7b-instruct-v0.3",
"mistralai/mistral-7b-instruct:free",
"mistralai/mistral-7b-instruct",
"microsoft/phi-3-mini-128k-instruct",
"microsoft/phi-3-medium-128k-instruct",
"google/gemini-flash-1.5",
"openai/gpt-4o-2024-05-13",
"openai/gpt-4o",
"openai/gpt-4o:extended",
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
"mistralai/mixtral-8x22b-instruct",
"openai/gpt-4-turbo",
"google/gemini-pro-1.5",
"cohere/command-r-plus",
"cohere/command-r-plus-04-2024",
"cohere/command-r",
"anthropic/claude-3-haiku:beta",
"anthropic/claude-3-haiku",
"anthropic/claude-3-opus:beta",
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet:beta",
"anthropic/claude-3-sonnet",
"cohere/command-r-03-2024",
"mistralai/mistral-large",
"openai/gpt-3.5-turbo-0613",
"openai/gpt-4-turbo-preview",
"mistralai/mistral-small",
"mistralai/mistral-tiny",
"mistralai/mixtral-8x7b-instruct",
"openai/gpt-4-1106-preview",
"mistralai/mistral-7b-instruct-v0.1",
"openai/gpt-3.5-turbo-16k",
"openai/gpt-4",
"openai/gpt-4-0314"
],
"title": "Model ID",
"description": "ID of the preferred model"
},
"provider": {
"type": "string",
"enum": [
"anthropic",
"openai",
"gemini",
"azure",
"bedrock",
"vertex",
"xai",
"openrouter"
],
"title": "Provider",
"description": "Provider for the preferred model"
},
"reasoning_effort": {
"type": "string",
"title": "Reasoning Effort",
"description": "Override reasoning effort for this model"
},
"max_tokens": {
"type": "integer",
"minimum": 1,
"title": "Max Tokens",
"description": "Override max tokens for this model"
},
"think": {
"type": "boolean",
"title": "Think",
"description": "Enable thinking for reasoning models",
"default": false
}
},
"type": "object",
"required": [
"model_id",
"provider"
]
},
"PreferredModels": {
"properties": {
"large": {
"$ref": "#/$defs/PreferredModel",
"title": "Large Model",
"description": "Preferred model configuration for large model type"
},
"small": {
"$ref": "#/$defs/PreferredModel",
"title": "Small Model",
"description": "Preferred model configuration for small model type"
}
},
"type": "object"
},
"ProviderConfig": {
"properties": {
"id": {
"type": "string",
"enum": [
"anthropic",
"openai",
"gemini",
"azure",
"bedrock",
"vertex",
"xai",
"openrouter"
],
"title": "Provider ID",
"description": "Unique identifier for the provider"
},
"base_url": {
"type": "string",
"title": "Base URL",
"description": "Base URL for the provider API (required for custom providers)"
},
"provider_type": {
"type": "string",
"title": "Provider Type",
"description": "Type of the provider (e.g. openai)"
},
"api_key": {
"type": "string",
"title": "API Key",
"description": "API key for authenticating with the provider"
},
"disabled": {
"type": "boolean",
"title": "Disabled",
"description": "Whether this provider is disabled",
"default": false
},
"extra_headers": {
"additionalProperties": {
"type": "string"
},
"type": "object",
"title": "Extra Headers",
"description": "Additional HTTP headers to send with requests"
},
"extra_params": {
"additionalProperties": {
"type": "string"
},
"type": "object",
"title": "Extra Parameters",
"description": "Additional provider-specific parameters"
},
"default_large_model": {
"type": "string",
"title": "Default Large Model",
"description": "Default model ID for large model type"
},
"default_small_model": {
"type": "string",
"title": "Default Small Model",
"description": "Default model ID for small model type"
},
"models": {
"items": {
"$ref": "#/$defs/Model"
},
"type": "array",
"title": "Models",
"description": "List of available models for this provider"
}
},
"type": "object",
"required": [
"provider_type"
]
},
"TUIOptions": {
"properties": {
"compact_mode": {
"type": "boolean",
"title": "Compact Mode",
"description": "Enable compact mode for the TUI",
"default": false
}
},
"type": "object",
"required": [
"compact_mode"
]
}
},
"properties": {
"models": {
"$ref": "#/$defs/PreferredModels",
"title": "Models",
"description": "Preferred model configurations for large and small model types"
},
"providers": {
"additionalProperties": {
"$ref": "#/$defs/ProviderConfig"
},
"type": "object",
"title": "Providers",
"description": "LLM provider configurations"
},
"agents": {
"additionalProperties": {
"$ref": "#/$defs/Agent"
},
"type": "object",
"title": "Agents",
"description": "Agent configurations for different tasks"
},
"mcp": {
"additionalProperties": {
"$ref": "#/$defs/MCP"
},
"type": "object",
"title": "MCP",
"description": "Model Control Protocol server configurations"
},
"lsp": {
"additionalProperties": {
"$ref": "#/$defs/LSPConfig"
},
"type": "object",
"title": "LSP",
"description": "Language Server Protocol configurations"
},
"options": {
"$ref": "#/$defs/Options",
"title": "Options",
"description": "General application options and settings"
}
},
"type": "object",
"title": "Crush Configuration",
"description": "Configuration schema for the Crush application"
}

View File

@@ -1,5 +1,4 @@
{
"$schema": "./crush-schema.json",
"lsp": {
"go": {
"command": "gopls"

16
go.mod
View File

@@ -17,6 +17,7 @@ require (
github.com/charmbracelet/fang v0.1.0
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c
github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a
@@ -25,31 +26,28 @@ require (
github.com/go-logfmt/logfmt v0.6.0
github.com/google/uuid v1.6.0
github.com/invopop/jsonschema v0.13.0
github.com/joho/godotenv v1.5.1
github.com/mark3labs/mcp-go v0.32.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.25.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/nxadm/tail v1.4.11
github.com/openai/openai-go v1.8.2
github.com/pressly/goose/v3 v3.24.2
github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
github.com/sahilm/fuzzy v0.1.1
github.com/spf13/cobra v1.9.1
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
github.com/stretchr/testify v1.10.0
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
gopkg.in/natefinch/lumberjack.v2 v2.2.1
mvdan.cc/sh/v3 v3.11.0
)
require (
github.com/charmbracelet/lipgloss v1.1.0 // indirect
github.com/charmbracelet/log v0.4.2 // indirect
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/nxadm/tail v1.4.11 // indirect
github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect
github.com/spf13/cast v1.7.1 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)
@@ -84,7 +82,7 @@ require (
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef
github.com/charmbracelet/x/input v0.3.5-0.20250509021451-13796e822d86 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/charmbracelet/x/term v0.2.1
github.com/charmbracelet/x/windows v0.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/disintegration/gift v1.1.2 // indirect

6
go.sum
View File

@@ -82,14 +82,8 @@ github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d
github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY=
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c h1:177KMz8zHRlEZJsWzafbKYh6OdjgvTspoH+UjaxgIXY=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 h1:X0tsNa2UHCKNw+illiavosasVzqioRo32SRV35iwr2I=
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE=
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM=
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqRF274IZ2UJ/mCTL8DR95efU9BNWLYCDXEjvQ=

View File

@@ -57,7 +57,8 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
cfg := config.Get()
coderAgentCfg := cfg.Agents[config.AgentCoder]
// TODO: remove the concept of agent config most likely
coderAgentCfg := cfg.Agents["coder"]
if coderAgentCfg.ID == "" {
return nil, fmt.Errorf("coder agent configuration is missing")
}

View File

@@ -38,7 +38,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
defer cancel()
// Initialize with the initialization context
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
_, err = lspClient.InitializeLSPClient(initCtx, config.Get().WorkingDir())
if err != nil {
logging.Error("Initialize failed", "name", name, "error", err)
// Clean up the client to prevent resource leaks
@@ -91,7 +91,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW
app.restartLSPClient(ctx, name)
})
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
workspaceWatcher.WatchWorkspace(ctx, config.Get().WorkingDir())
logging.Info("Workspace watcher stopped", "client", name)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,71 +0,0 @@
package config
import (
"fmt"
"os"
"path/filepath"
"runtime"
)
var testConfigDir string
// baseConfigPath returns the directory that holds the application's config
// file. Resolution order: the test override, $XDG_CONFIG_HOME, then the
// platform default (%LOCALAPPDATA% on Windows, $HOME/.config elsewhere).
func baseConfigPath() string {
	if testConfigDir != "" {
		return testConfigDir
	}
	if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
		// Consistency fix: use appName here like every other branch (and
		// like baseDataPath) instead of a hard-coded "crush" literal.
		return filepath.Join(xdgConfigHome, appName)
	}
	// For windows, config lives in `%LOCALAPPDATA%/crush/`; fall back to
	// the conventional location under %USERPROFILE% if unset.
	if runtime.GOOS == "windows" {
		localAppData := os.Getenv("LOCALAPPDATA")
		if localAppData == "" {
			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
		}
		return filepath.Join(localAppData, appName)
	}
	// Linux and macOS: `$HOME/.config/crush/`.
	return filepath.Join(os.Getenv("HOME"), ".config", appName)
}
// baseDataPath returns the directory used for persistent application data.
// Resolution order: the test override, $XDG_DATA_HOME, then the platform
// default (%LOCALAPPDATA% on Windows, $HOME/.local/share elsewhere).
func baseDataPath() string {
	if testConfigDir != "" {
		return testConfigDir
	}
	if xdg := os.Getenv("XDG_DATA_HOME"); xdg != "" {
		return filepath.Join(xdg, appName)
	}
	if runtime.GOOS != "windows" {
		// Linux and macOS: `$HOME/.local/share/crush/`.
		return filepath.Join(os.Getenv("HOME"), ".local", "share", appName)
	}
	// Windows: `%LOCALAPPDATA%/crush/`, with a conventional fallback under
	// %USERPROFILE% when LOCALAPPDATA is unset.
	localAppData := os.Getenv("LOCALAPPDATA")
	if localAppData == "" {
		localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
	}
	return filepath.Join(localAppData, appName)
}
// ConfigPath returns the full path of the application's JSON config file,
// i.e. <baseConfigPath>/<appName>.json.
func ConfigPath() string {
	return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName))
}
// CrushInitialized reports whether Crush has been initialized, i.e. whether
// the config file already exists on disk.
func CrushInitialized() bool {
	_, err := os.Stat(ConfigPath())
	return !os.IsNotExist(err)
}

View File

@@ -5,27 +5,53 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"github.com/charmbracelet/crush/internal/logging"
)
const (
// InitFlagFilename is the name of the file that indicates whether the project has been initialized
InitFlagFilename = "init"
)
// ProjectInitFlag represents the initialization status for a project directory
type ProjectInitFlag struct {
Initialized bool `json:"initialized"`
}
// ProjectNeedsInitialization checks if the current project needs initialization
// TODO: we need to remove the global config instance keeping it now just until everything is migrated
var (
instance atomic.Pointer[Config]
cwd string
once sync.Once // Ensures the initialization happens only once
)
// Init loads the global configuration exactly once for the given working
// directory and stores it in the process-wide instance; later calls return
// the already-loaded config. The returned error now reflects the load that
// actually ran: previously `cfg, err :=` inside the closure shadowed the
// outer err, so Init always returned a nil error even when Load failed.
func Init(workingDir string, debug bool) (*Config, error) {
	var err error
	once.Do(func() {
		cwd = workingDir
		var cfg *Config
		cfg, err = Load(cwd, debug) // assign, don't shadow, the outer err
		if err != nil {
			logging.Error("Failed to load config", "error", err)
		}
		instance.Store(cfg)
	})
	return instance.Load(), err
}
// Get returns the process-wide Config stored by Init, or nil if Init has not
// run yet (or its load failed and stored nil).
func Get() *Config {
	return instance.Load()
}
func ProjectNeedsInitialization() (bool, error) {
if instance == nil {
cfg := Get()
if cfg == nil {
return false, fmt.Errorf("config not loaded")
}
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
// Check if the flag file exists
_, err := os.Stat(flagFilePath)
if err == nil {
return false, nil
@@ -35,8 +61,7 @@ func ProjectNeedsInitialization() (bool, error) {
return false, fmt.Errorf("failed to check init flag file: %w", err)
}
// Check if any variation of CRUSH.md already exists in working directory
crushExists, err := crushMdExists(WorkingDirectory())
crushExists, err := crushMdExists(cfg.WorkingDir())
if err != nil {
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
}
@@ -47,7 +72,6 @@ func ProjectNeedsInitialization() (bool, error) {
return true, nil
}
// crushMdExists checks if any case variation of crush.md exists in the directory
func crushMdExists(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
@@ -68,12 +92,12 @@ func crushMdExists(dir string) (bool, error) {
return false, nil
}
// MarkProjectInitialized marks the current project as initialized
func MarkProjectInitialized() error {
if instance == nil {
cfg := Get()
if cfg == nil {
return fmt.Errorf("config not loaded")
}
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
file, err := os.Create(flagFilePath)
if err != nil {

View File

@@ -10,10 +10,10 @@ import (
"slices"
"strings"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/pkg/env"
"github.com/charmbracelet/crush/pkg/log"
"github.com/charmbracelet/crush/internal/log"
"golang.org/x/exp/slog"
)
@@ -68,6 +68,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
env := env.New()
// Configure providers
valueResolver := NewShellVariableResolver(env)
cfg.resolver = valueResolver
if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
@@ -81,6 +82,36 @@ func Load(workingDir string, debug bool) (*Config, error) {
return nil, fmt.Errorf("failed to configure selected models: %w", err)
}
// TODO: remove the agents concept from the config
agents := map[string]Agent{
"coder": {
ID: "coder",
Name: "Coder",
Description: "An agent that helps with executing coding tasks.",
Model: SelectedModelTypeLarge,
ContextPaths: cfg.Options.ContextPaths,
// All tools allowed
},
"task": {
ID: "task",
Name: "Task",
Description: "An agent that helps with searching for context and finding implementation details.",
Model: SelectedModelTypeLarge,
ContextPaths: cfg.Options.ContextPaths,
AllowedTools: []string{
"glob",
"grep",
"ls",
"sourcegraph",
"view",
},
// NO MCPs or LSPs by default
AllowedMCP: map[string][]string{},
AllowedLSP: []string{},
},
}
cfg.Agents = agents
return cfg, nil
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/pkg/env"
"github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
)

View File

@@ -4,27 +4,44 @@ import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
)
var fur = client.New()
var (
providerOnc sync.Once // Ensures the initialization happens only once
providerList []provider.Provider
// UseMockProviders can be set to true in tests to avoid API calls
UseMockProviders bool
)
func providersPath() string {
return filepath.Join(baseDataPath(), "providers.json")
// ProviderClient abstracts the upstream provider-catalog API so that
// provider loading can be exercised with a fake client in tests.
type ProviderClient interface {
	GetProviders() ([]provider.Provider, error)
}
func saveProviders(providers []provider.Provider) error {
path := providersPath()
var (
providerOnce sync.Once
providerList []provider.Provider
)
// providerCacheFileData returns the path of the file used to cache provider
// data. Resolution order: $XDG_DATA_HOME, then the platform default
// (%LOCALAPPDATA% on Windows, $HOME/.local/share elsewhere). Every branch
// now appends the "providers.json" filename; previously only the final
// branch did, so XDG and Windows users were handed a directory path.
func providerCacheFileData() string {
	const cacheFile = "providers.json"
	if xdgDataHome := os.Getenv("XDG_DATA_HOME"); xdgDataHome != "" {
		return filepath.Join(xdgDataHome, appName, cacheFile)
	}
	// For windows, the cache lives under `%LOCALAPPDATA%/crush/`; fall back
	// to the conventional location under %USERPROFILE% if unset.
	if runtime.GOOS == "windows" {
		localAppData := os.Getenv("LOCALAPPDATA")
		if localAppData == "" {
			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
		}
		return filepath.Join(localAppData, appName, cacheFile)
	}
	// Linux and macOS: `$HOME/.local/share/crush/providers.json`.
	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, cacheFile)
}
func saveProvidersInCache(path string, providers []provider.Provider) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
@@ -38,8 +55,7 @@ func saveProviders(providers []provider.Provider) error {
return os.WriteFile(path, data, 0o644)
}
func loadProviders() ([]provider.Provider, error) {
path := providersPath()
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@@ -50,34 +66,33 @@ func loadProviders() ([]provider.Provider, error) {
return providers, err
}
func Providers() []provider.Provider {
providerOnc.Do(func() {
// Use mock providers when testing
if UseMockProviders {
providerList = MockProviders()
return
// loadProviders returns the provider list, preferring fresh data from the
// upstream client. When the client fails, it falls back to the on-disk cache
// at path; fresh data is persisted to that cache for future fallback use.
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
	providers, err := client.GetProviders()
	if err != nil {
		// Upstream failed: serve the cached copy if one exists.
		cached, cacheErr := loadProvidersFromCache(path)
		if cacheErr != nil {
			return nil, cacheErr
		}
		return cached, nil
	}
	// Cache the data we just fetched. The original saved the package-level
	// providerList here, which is still nil at this point, so the cache file
	// was always written as "null".
	if err := saveProvidersInCache(path, providers); err != nil {
		return nil, err
	}
	return providers, nil
}
// Try to get providers from upstream API
if providers, err := fur.GetProviders(); err == nil {
providerList = providers
// Save providers locally for future fallback
_ = saveProviders(providers)
} else {
// If upstream fails, try to load from local cache
if localProviders, localErr := loadProviders(); localErr == nil {
providerList = localProviders
} else {
// If both fail, return empty list
providerList = []provider.Provider{}
}
}
// Providers returns the provider list, fetched via the default upstream
// client with on-disk cache fallback (see LoadProviders).
func Providers() ([]provider.Provider, error) {
	return LoadProviders(client.New())
}
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
var err error
providerOnce.Do(func() {
providerList, err = loadProviders(providerCacheFileData(), client)
})
return providerList
}
// ResetProviders resets the provider cache. Useful for testing.
func ResetProviders() {
providerOnc = sync.Once{}
providerList = nil
if err != nil {
return nil, err
}
return providerList, nil
}

View File

@@ -1,293 +0,0 @@
package config
import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
// MockProviders returns a mock list of providers for testing.
// This avoids making API calls during tests and provides consistent test data.
// Simplified version with only default models from each provider.
func MockProviders() []provider.Provider {
return []provider.Provider{
{
Name: "Anthropic",
ID: provider.InferenceProviderAnthropic,
APIKey: "$ANTHROPIC_API_KEY",
APIEndpoint: "$ANTHROPIC_API_ENDPOINT",
Type: provider.TypeAnthropic,
DefaultLargeModelID: "claude-sonnet-4-20250514",
DefaultSmallModelID: "claude-3-5-haiku-20241022",
Models: []provider.Model{
{
ID: "claude-sonnet-4-20250514",
Name: "Claude Sonnet 4",
CostPer1MIn: 3.0,
CostPer1MOut: 15.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.3,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
{
ID: "claude-3-5-haiku-20241022",
Name: "Claude 3.5 Haiku",
CostPer1MIn: 0.8,
CostPer1MOut: 4.0,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
ContextWindow: 200000,
DefaultMaxTokens: 5000,
CanReason: false,
SupportsImages: true,
},
},
},
{
Name: "OpenAI",
ID: provider.InferenceProviderOpenAI,
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "$OPENAI_API_ENDPOINT",
Type: provider.TypeOpenAI,
DefaultLargeModelID: "codex-mini-latest",
DefaultSmallModelID: "gpt-4o",
Models: []provider.Model{
{
ID: "codex-mini-latest",
Name: "Codex Mini",
CostPer1MIn: 1.5,
CostPer1MOut: 6.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.375,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
HasReasoningEffort: true,
DefaultReasoningEffort: "medium",
SupportsImages: true,
},
{
ID: "gpt-4o",
Name: "GPT-4o",
CostPer1MIn: 2.5,
CostPer1MOut: 10.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 1.25,
ContextWindow: 128000,
DefaultMaxTokens: 20000,
CanReason: false,
SupportsImages: true,
},
},
},
{
Name: "Google Gemini",
ID: provider.InferenceProviderGemini,
APIKey: "$GEMINI_API_KEY",
APIEndpoint: "$GEMINI_API_ENDPOINT",
Type: provider.TypeGemini,
DefaultLargeModelID: "gemini-2.5-pro",
DefaultSmallModelID: "gemini-2.5-flash",
Models: []provider.Model{
{
ID: "gemini-2.5-pro",
Name: "Gemini 2.5 Pro",
CostPer1MIn: 1.25,
CostPer1MOut: 10.0,
CostPer1MInCached: 1.625,
CostPer1MOutCached: 0.31,
ContextWindow: 1048576,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
{
ID: "gemini-2.5-flash",
Name: "Gemini 2.5 Flash",
CostPer1MIn: 0.3,
CostPer1MOut: 2.5,
CostPer1MInCached: 0.3833,
CostPer1MOutCached: 0.075,
ContextWindow: 1048576,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
},
},
{
Name: "xAI",
ID: provider.InferenceProviderXAI,
APIKey: "$XAI_API_KEY",
APIEndpoint: "https://api.x.ai/v1",
Type: provider.TypeXAI,
DefaultLargeModelID: "grok-3",
DefaultSmallModelID: "grok-3-mini",
Models: []provider.Model{
{
ID: "grok-3",
Name: "Grok 3",
CostPer1MIn: 3.0,
CostPer1MOut: 15.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.75,
ContextWindow: 131072,
DefaultMaxTokens: 20000,
CanReason: false,
SupportsImages: false,
},
{
ID: "grok-3-mini",
Name: "Grok 3 Mini",
CostPer1MIn: 0.3,
CostPer1MOut: 0.5,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.075,
ContextWindow: 131072,
DefaultMaxTokens: 20000,
CanReason: true,
SupportsImages: false,
},
},
},
{
Name: "Azure OpenAI",
ID: provider.InferenceProviderAzure,
APIKey: "$AZURE_OPENAI_API_KEY",
APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT",
Type: provider.TypeAzure,
DefaultLargeModelID: "o4-mini",
DefaultSmallModelID: "gpt-4o",
Models: []provider.Model{
{
ID: "o4-mini",
Name: "o4 Mini",
CostPer1MIn: 1.1,
CostPer1MOut: 4.4,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.275,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
HasReasoningEffort: false,
DefaultReasoningEffort: "medium",
SupportsImages: true,
},
{
ID: "gpt-4o",
Name: "GPT-4o",
CostPer1MIn: 2.5,
CostPer1MOut: 10.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 1.25,
ContextWindow: 128000,
DefaultMaxTokens: 20000,
CanReason: false,
SupportsImages: true,
},
},
},
{
Name: "AWS Bedrock",
ID: provider.InferenceProviderBedrock,
Type: provider.TypeBedrock,
DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0",
DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0",
Models: []provider.Model{
{
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
Name: "AWS Claude Sonnet 4",
CostPer1MIn: 3.0,
CostPer1MOut: 15.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.3,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
{
ID: "anthropic.claude-3-5-haiku-20241022-v1:0",
Name: "AWS Claude 3.5 Haiku",
CostPer1MIn: 0.8,
CostPer1MOut: 4.0,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
ContextWindow: 200000,
DefaultMaxTokens: 50000,
CanReason: false,
SupportsImages: true,
},
},
},
{
Name: "Google Vertex AI",
ID: provider.InferenceProviderVertexAI,
Type: provider.TypeVertexAI,
DefaultLargeModelID: "gemini-2.5-pro",
DefaultSmallModelID: "gemini-2.5-flash",
Models: []provider.Model{
{
ID: "gemini-2.5-pro",
Name: "Gemini 2.5 Pro",
CostPer1MIn: 1.25,
CostPer1MOut: 10.0,
CostPer1MInCached: 1.625,
CostPer1MOutCached: 0.31,
ContextWindow: 1048576,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
{
ID: "gemini-2.5-flash",
Name: "Gemini 2.5 Flash",
CostPer1MIn: 0.3,
CostPer1MOut: 2.5,
CostPer1MInCached: 0.3833,
CostPer1MOutCached: 0.075,
ContextWindow: 1048576,
DefaultMaxTokens: 50000,
CanReason: true,
SupportsImages: true,
},
},
},
{
Name: "OpenRouter",
ID: provider.InferenceProviderOpenRouter,
APIKey: "$OPENROUTER_API_KEY",
APIEndpoint: "https://openrouter.ai/api/v1",
Type: provider.TypeOpenAI,
DefaultLargeModelID: "anthropic/claude-sonnet-4",
DefaultSmallModelID: "anthropic/claude-haiku-3.5",
Models: []provider.Model{
{
ID: "anthropic/claude-sonnet-4",
Name: "Anthropic: Claude Sonnet 4",
CostPer1MIn: 3.0,
CostPer1MOut: 15.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.3,
ContextWindow: 200000,
DefaultMaxTokens: 32000,
CanReason: true,
SupportsImages: true,
},
{
ID: "anthropic/claude-haiku-3.5",
Name: "Anthropic: Claude 3.5 Haiku",
CostPer1MIn: 0.8,
CostPer1MOut: 4.0,
CostPer1MInCached: 1.0,
CostPer1MOutCached: 0.08,
ContextWindow: 200000,
DefaultMaxTokens: 4096,
CanReason: false,
SupportsImages: true,
},
},
},
}
}

View File

@@ -1,81 +1,73 @@
package config
import (
	"encoding/json"
	"errors"
	"os"
	"path/filepath"
	"testing"

	"github.com/charmbracelet/crush/internal/fur/provider"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
func TestProviders_MockEnabled(t *testing.T) {
originalUseMock := UseMockProviders
UseMockProviders = true
defer func() {
UseMockProviders = originalUseMock
ResetProviders()
}()
ResetProviders()
providers := Providers()
require.NotEmpty(t, providers)
providerIDs := make(map[provider.InferenceProvider]bool)
for _, p := range providers {
providerIDs[p.ID] = true
}
assert.True(t, providerIDs[provider.InferenceProviderAnthropic])
assert.True(t, providerIDs[provider.InferenceProviderOpenAI])
assert.True(t, providerIDs[provider.InferenceProviderGemini])
// mockProviderClient is a test double for ProviderClient. When shouldFail is
// true, GetProviders returns an error instead of provider data.
type mockProviderClient struct {
	shouldFail bool
}
func TestProviders_ResetFunctionality(t *testing.T) {
UseMockProviders = true
defer func() {
UseMockProviders = false
ResetProviders()
}()
providers1 := Providers()
require.NotEmpty(t, providers1)
ResetProviders()
providers2 := Providers()
require.NotEmpty(t, providers2)
assert.Equal(t, len(providers1), len(providers2))
// GetProviders implements ProviderClient. It fails when shouldFail is set;
// otherwise it returns a single stub provider named "Mock".
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
	if m.shouldFail {
		return nil, errors.New("failed to load providers")
	}
	return []provider.Provider{
		{
			Name: "Mock",
		},
	}, nil
}
func TestProviders_ModelCapabilities(t *testing.T) {
originalUseMock := UseMockProviders
UseMockProviders = true
defer func() {
UseMockProviders = originalUseMock
ResetProviders()
}()
func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(tmpPath, client)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
ResetProviders()
providers := Providers()
var openaiProvider provider.Provider
for _, p := range providers {
if p.ID == provider.InferenceProviderOpenAI {
openaiProvider = p
break
}
}
require.NotEmpty(t, openaiProvider.ID)
var foundReasoning, foundNonReasoning bool
for _, model := range openaiProvider.Models {
if model.CanReason && model.HasReasoningEffort {
foundReasoning = true
} else if !model.CanReason {
foundNonReasoning = true
}
}
assert.True(t, foundReasoning)
assert.True(t, foundNonReasoning)
// check if file got saved
fileInfo, err := os.Stat(tmpPath)
assert.NoError(t, err)
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
}
// TestProvider_loadProvidersWithIssues verifies that loadProviders falls
// back to the on-disk cache when the upstream client fails.
func TestProvider_loadProvidersWithIssues(t *testing.T) {
	client := &mockProviderClient{shouldFail: true}
	// filepath.Join keeps the path valid on Windows as well as Unix.
	tmpPath := filepath.Join(t.TempDir(), "providers.json")
	// Seed the cache file with a previously stored provider list.
	oldProviders := []provider.Provider{
		{
			Name: "OldProvider",
		},
	}
	data, err := json.Marshal(oldProviders)
	require.NoError(t, err, "marshal old providers")
	require.NoError(t, os.WriteFile(tmpPath, data, 0o644), "write providers cache")
	providers, err := loadProviders(tmpPath, client)
	assert.NoError(t, err)
	assert.NotNil(t, providers)
	assert.Len(t, providers, 1)
	assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
}
// TestProvider_loadProvidersWithIssuesAndNoCache verifies that loadProviders
// surfaces an error when the upstream client fails and no cache file exists.
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
	client := &mockProviderClient{shouldFail: true}
	// filepath.Join keeps the path portable across operating systems.
	tmpPath := filepath.Join(t.TempDir(), "providers.json")
	providers, err := loadProviders(tmpPath, client)
	assert.Error(t, err)
	assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}

View File

@@ -7,7 +7,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/shell"
"github.com/charmbracelet/crush/pkg/env"
"github.com/charmbracelet/crush/internal/env"
)
type VariableResolver interface {

View File

@@ -5,7 +5,7 @@ import (
"errors"
"testing"
"github.com/charmbracelet/crush/pkg/env"
"github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
)

View File

@@ -1,73 +0,0 @@
package config
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/shell"
)
// ExecuteCommand executes a shell command and returns the output
// This is a shared utility that can be used by both provider config and tools
func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) {
	// Default to the configured working directory when none is given.
	if workingDir == "" {
		workingDir = WorkingDirectory()
	}
	// NOTE(review): a fresh shell is created on every call, so no state
	// persists between invocations despite the variable name.
	persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir})
	stdout, stderr, err := persistentShell.Exec(ctx, command)
	if err != nil {
		// stderr is only logged; the returned error wraps err alone.
		logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr)
		return "", fmt.Errorf("command execution failed: %w", err)
	}
	// Trim surrounding whitespace (trailing newline) so callers get a clean value.
	return strings.TrimSpace(stdout), nil
}
// ResolveAPIKey resolves an API key that can be either:
//   - A direct string value
//   - An environment variable (prefixed with $)
//   - A shell command (wrapped in $(...))
func ResolveAPIKey(apiKey string) (string, error) {
	// Plain values are returned untouched.
	if !strings.HasPrefix(apiKey, "$") {
		return apiKey, nil
	}
	// $(command) form: run the command and use its output as the key.
	if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") {
		command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")")
		logging.Debug("Resolving API key from command", "command", command)
		return resolveCommandAPIKey(command)
	}
	// $VAR form: read the environment variable.
	envVar := strings.TrimPrefix(apiKey, "$")
	if value := os.Getenv(envVar); value != "" {
		// Never log the resolved value: it is a secret credential.
		logging.Debug("Resolved environment variable", "envVar", envVar)
		return value, nil
	}
	logging.Debug("Environment variable not found", "envVar", envVar)
	return "", fmt.Errorf("environment variable %s not found", envVar)
}
// resolveCommandAPIKey executes a command to obtain an API key. The command
// runs in the configured working directory with a 30-second timeout.
func resolveCommandAPIKey(command string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	logging.Debug("Executing command for API key", "command", command)
	workingDir := WorkingDirectory()
	result, err := ExecuteCommand(ctx, command, workingDir)
	if err != nil {
		return "", fmt.Errorf("failed to execute API key command: %w", err)
	}
	// Log success without the output itself: the result is a secret credential.
	logging.Debug("Command executed successfully", "command", command)
	return result, nil
}

View File

@@ -1,462 +0,0 @@
package config
import (
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfig_Validate_ValidConfig(t *testing.T) {
cfg := &Config{
Models: PreferredModels{
Large: PreferredModel{
ModelID: "gpt-4",
Provider: provider.InferenceProviderOpenAI,
},
Small: PreferredModel{
ModelID: "gpt-3.5-turbo",
Provider: provider.InferenceProviderOpenAI,
},
},
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
DefaultLargeModel: "gpt-4",
DefaultSmallModel: "gpt-3.5-turbo",
Models: []Model{
{
ID: "gpt-4",
Name: "GPT-4",
ContextWindow: 8192,
DefaultMaxTokens: 4096,
CostPer1MIn: 30.0,
CostPer1MOut: 60.0,
},
{
ID: "gpt-3.5-turbo",
Name: "GPT-3.5 Turbo",
ContextWindow: 4096,
DefaultMaxTokens: 2048,
CostPer1MIn: 1.5,
CostPer1MOut: 2.0,
},
},
},
},
Agents: map[AgentID]Agent{
AgentCoder: {
ID: AgentCoder,
Name: "Coder",
Description: "An agent that helps with executing coding tasks.",
Model: LargeModel,
ContextPaths: []string{"CRUSH.md"},
},
AgentTask: {
ID: AgentTask,
Name: "Task",
Description: "An agent that helps with searching for context and finding implementation details.",
Model: LargeModel,
ContextPaths: []string{"CRUSH.md"},
AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"},
AllowedMCP: map[string][]string{},
AllowedLSP: []string{},
},
},
MCP: map[string]MCP{},
LSP: map[string]LSPConfig{},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
assert.NoError(t, err)
}
func TestConfig_Validate_MissingAPIKey(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
ProviderType: provider.TypeOpenAI,
// Missing APIKey
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "API key is required")
}
func TestConfig_Validate_InvalidProviderType(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.Type("invalid"),
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid provider type")
}
func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) {
customProvider := provider.InferenceProvider("custom-provider")
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
customProvider: {
ID: customProvider,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
// Missing BaseURL for custom provider
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "BaseURL is required for custom providers")
}
func TestConfig_Validate_DuplicateModelIDs(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
Models: []Model{
{
ID: "gpt-4",
Name: "GPT-4",
ContextWindow: 8192,
DefaultMaxTokens: 4096,
},
{
ID: "gpt-4", // Duplicate ID
Name: "GPT-4 Duplicate",
ContextWindow: 8192,
DefaultMaxTokens: 4096,
},
},
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "duplicate model ID")
}
func TestConfig_Validate_InvalidModelFields(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
Models: []Model{
{
ID: "", // Empty ID
Name: "GPT-4",
ContextWindow: 0, // Invalid context window
DefaultMaxTokens: -1, // Invalid max tokens
CostPer1MIn: -5.0, // Negative cost
},
},
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
validationErr := err.(ValidationErrors)
assert.True(t, len(validationErr) >= 4) // Should have multiple validation errors
}
func TestConfig_Validate_DefaultModelNotFound(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
DefaultLargeModel: "nonexistent-model",
Models: []Model{
{
ID: "gpt-4",
Name: "GPT-4",
ContextWindow: 8192,
DefaultMaxTokens: 4096,
},
},
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found")
}
func TestConfig_Validate_AgentIDMismatch(t *testing.T) {
cfg := &Config{
Agents: map[AgentID]Agent{
AgentCoder: {
ID: AgentTask, // Wrong ID
Name: "Coder",
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "agent ID mismatch")
}
func TestConfig_Validate_InvalidAgentModelType(t *testing.T) {
cfg := &Config{
Agents: map[AgentID]Agent{
AgentCoder: {
ID: AgentCoder,
Name: "Coder",
Model: ModelType("invalid"),
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid model type")
}
func TestConfig_Validate_UnknownTool(t *testing.T) {
cfg := &Config{
Agents: map[AgentID]Agent{
AgentID("custom-agent"): {
ID: AgentID("custom-agent"),
Name: "Custom Agent",
Model: LargeModel,
AllowedTools: []string{"unknown-tool"},
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown tool")
}
func TestConfig_Validate_MCPReference(t *testing.T) {
cfg := &Config{
Agents: map[AgentID]Agent{
AgentID("custom-agent"): {
ID: AgentID("custom-agent"),
Name: "Custom Agent",
Model: LargeModel,
AllowedMCP: map[string][]string{"nonexistent-mcp": nil},
},
},
MCP: map[string]MCP{}, // Empty MCP map
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found")
}
func TestConfig_Validate_InvalidMCPType(t *testing.T) {
cfg := &Config{
MCP: map[string]MCP{
"test-mcp": {
Type: MCPType("invalid"),
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid MCP type")
}
func TestConfig_Validate_MCPMissingCommand(t *testing.T) {
cfg := &Config{
MCP: map[string]MCP{
"test-mcp": {
Type: MCPStdio,
// Missing Command
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "command is required for stdio MCP")
}
func TestConfig_Validate_LSPMissingCommand(t *testing.T) {
cfg := &Config{
LSP: map[string]LSPConfig{
"test-lsp": {
// Missing Command
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "command is required for LSP")
}
func TestConfig_Validate_NoValidProviders(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
Disabled: true, // Disabled
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one non-disabled provider is required")
}
func TestConfig_Validate_MissingDefaultAgents(t *testing.T) {
cfg := &Config{
Providers: map[provider.InferenceProvider]ProviderConfig{
provider.InferenceProviderOpenAI: {
ID: provider.InferenceProviderOpenAI,
APIKey: "test-key",
ProviderType: provider.TypeOpenAI,
},
},
Agents: map[AgentID]Agent{}, // Missing default agents
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "coder agent is required")
assert.Contains(t, err.Error(), "task agent is required")
}
func TestConfig_Validate_KnownAgentProtection(t *testing.T) {
cfg := &Config{
Agents: map[AgentID]Agent{
AgentCoder: {
ID: AgentCoder,
Name: "Modified Coder", // Should not be allowed
Description: "Modified description", // Should not be allowed
Model: LargeModel,
},
},
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "coder agent name cannot be changed")
assert.Contains(t, err.Error(), "coder agent description cannot be changed")
}
func TestConfig_Validate_EmptyDataDirectory(t *testing.T) {
cfg := &Config{
Options: Options{
DataDirectory: "", // Empty
ContextPaths: []string{"CRUSH.md"},
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "data directory is required")
}
func TestConfig_Validate_EmptyContextPath(t *testing.T) {
cfg := &Config{
Options: Options{
DataDirectory: ".crush",
ContextPaths: []string{""}, // Empty context path
},
}
err := cfg.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "context path cannot be empty")
}

View File

@@ -11,7 +11,7 @@ import (
func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) {
// remove the cwd prefix and ensure consistent path format
// this prevents issues with absolute paths in different environments
cwd := config.WorkingDirectory()
cwd := config.Get().WorkingDir()
fileName = strings.TrimPrefix(fileName, cwd)
fileName = strings.TrimPrefix(fileName, "/")

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
fur "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
@@ -49,7 +50,7 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
Model() config.Model
Model() fur.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
@@ -76,9 +77,9 @@ type agent struct {
activeRequests sync.Map
}
var agentPromptMap = map[config.AgentID]prompt.PromptID{
config.AgentCoder: prompt.PromptCoder,
config.AgentTask: prompt.PromptTask,
var agentPromptMap = map[string]prompt.PromptID{
"coder": prompt.PromptCoder,
"task": prompt.PromptTask,
}
func NewAgent(
@@ -109,8 +110,8 @@ func NewAgent(
tools.NewWriteTool(lspClients, permissions, history),
}
if agentCfg.ID == config.AgentCoder {
taskAgentCfg := config.Get().Agents[config.AgentTask]
if agentCfg.ID == "coder" {
taskAgentCfg := config.Get().Agents["task"]
if taskAgentCfg.ID == "" {
return nil, fmt.Errorf("task agent not found in config")
}
@@ -130,13 +131,13 @@ func NewAgent(
}
allTools = append(allTools, otherTools...)
providerCfg := config.GetAgentProvider(agentCfg.ID)
if providerCfg.ID == "" {
providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
if providerCfg == nil {
return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
}
model := config.GetAgentModel(agentCfg.ID)
model := config.Get().GetModelByType(agentCfg.Model)
if model.ID == "" {
if model == nil {
return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
}
@@ -148,51 +149,40 @@ func NewAgent(
provider.WithModel(agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
}
agentProvider, err := provider.NewProvider(providerCfg, opts...)
agentProvider, err := provider.NewProvider(*providerCfg, opts...)
if err != nil {
return nil, err
}
smallModelCfg := cfg.Models.Small
var smallModel config.Model
var smallModelProviderCfg config.ProviderConfig
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg *config.ProviderConfig
if smallModelCfg.Provider == providerCfg.ID {
smallModelProviderCfg = providerCfg
} else {
for _, p := range cfg.Providers {
if p.ID == smallModelCfg.Provider {
smallModelProviderCfg = p
break
}
}
smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
if smallModelProviderCfg.ID == "" {
return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
}
for _, m := range smallModelProviderCfg.Models {
if m.ID == smallModelCfg.ModelID {
smallModel = m
break
}
}
smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
if smallModel.ID == "" {
return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
}
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
}
titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
if err != nil {
return nil, err
}
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
}
summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
if err != nil {
return nil, err
}
@@ -225,8 +215,8 @@ func NewAgent(
return agent, nil
}
func (a *agent) Model() config.Model {
return config.GetAgentModel(a.agentCfg.ID)
func (a *agent) Model() fur.Model {
return *config.Get().GetModelByType(a.agentCfg.Model)
}
func (a *agent) Cancel(sessionID string) {
@@ -610,7 +600,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
@@ -819,7 +809,7 @@ func (a *agent) UpdateModel() error {
cfg := config.Get()
// Get current provider configuration
currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
if currentProviderCfg.ID == "" {
return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
}
@@ -827,7 +817,7 @@ func (a *agent) UpdateModel() error {
// Check if provider has changed
if string(currentProviderCfg.ID) != a.providerID {
// Provider changed, need to recreate the main provider
model := config.GetAgentModel(a.agentCfg.ID)
model := cfg.GetModelByType(a.agentCfg.Model)
if model.ID == "" {
return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
}
@@ -842,7 +832,7 @@ func (a *agent) UpdateModel() error {
provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
}
newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
if err != nil {
return fmt.Errorf("failed to create new provider: %w", err)
}
@@ -853,7 +843,7 @@ func (a *agent) UpdateModel() error {
}
// Check if small model provider has changed (affects title and summarize providers)
smallModelCfg := cfg.Models.Small
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
for _, p := range cfg.Providers {
@@ -869,20 +859,14 @@ func (a *agent) UpdateModel() error {
// Check if summarize provider has changed
if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
var smallModel config.Model
for _, m := range smallModelProviderCfg.Models {
if m.ID == smallModelCfg.ModelID {
smallModel = m
break
}
}
if smallModel.ID == "" {
return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
if smallModel == nil {
return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
}
// Recreate title provider
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
// We want the title to be short, so we limit the max tokens
provider.WithMaxTokens(40),
@@ -894,7 +878,7 @@ func (a *agent) UpdateModel() error {
// Recreate summarize provider
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SmallModel),
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
}
newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)

View File

@@ -19,7 +19,7 @@ import (
type mcpTool struct {
mcpName string
tool mcp.Tool
mcpConfig config.MCP
mcpConfig config.MCPConfig
permissions permission.Service
}
@@ -97,7 +97,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
p := b.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),
Path: config.Get().WorkingDir(),
ToolName: b.Info().Name,
Action: "execute",
Description: permissionDescription,
@@ -142,7 +142,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("invalid mcp type"), nil
}
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
return &mcpTool{
mcpName: name,
tool: tool,
@@ -153,7 +153,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
var mcpTools []tools.BaseTool
func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION

View File

@@ -14,12 +14,12 @@ import (
"github.com/charmbracelet/crush/internal/logging"
)
func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
func CoderPrompt(p string, contextFiles ...string) string {
var basePrompt string
switch p {
case provider.InferenceProviderOpenAI:
case string(provider.InferenceProviderOpenAI):
basePrompt = baseOpenAICoderPrompt
case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI:
case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
basePrompt = baseGeminiCoderPrompt
default:
basePrompt = baseAnthropicCoderPrompt
@@ -380,7 +380,7 @@ Your core function is efficient and safe assistance. Balance extreme conciseness
`
func getEnvironmentInfo() string {
cwd := config.WorkingDirectory()
cwd := config.Get().WorkingDir()
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")

View File

@@ -7,7 +7,6 @@ import (
"sync"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
)
type PromptID string
@@ -20,17 +19,17 @@ const (
PromptDefault PromptID = "default"
)
func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
basePrompt := ""
switch promptID {
case PromptCoder:
basePrompt = CoderPrompt(provider)
case PromptTitle:
basePrompt = TitlePrompt(provider)
basePrompt = TitlePrompt()
case PromptTask:
basePrompt = TaskPrompt(provider)
basePrompt = TaskPrompt()
case PromptSummarizer:
basePrompt = SummarizerPrompt(provider)
basePrompt = SummarizerPrompt()
default:
basePrompt = "You are a helpful assistant"
}
@@ -38,7 +37,7 @@ func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPa
}
func getContextFromPaths(contextPaths []string) string {
return processContextPaths(config.WorkingDirectory(), contextPaths)
return processContextPaths(config.Get().WorkingDir(), contextPaths)
}
func processContextPaths(workDir string, paths []string) string {

View File

@@ -1,10 +1,6 @@
package prompt
import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
func SummarizerPrompt(_ provider.InferenceProvider) string {
func SummarizerPrompt() string {
return `You are a helpful AI assistant tasked with summarizing conversations.
When asked to summarize, provide a detailed but concise summary of the conversation.

View File

@@ -2,11 +2,9 @@ package prompt
import (
"fmt"
"github.com/charmbracelet/crush/internal/fur/provider"
)
func TaskPrompt(_ provider.InferenceProvider) string {
func TaskPrompt() string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".

View File

@@ -1,10 +1,6 @@
package prompt
import (
"github.com/charmbracelet/crush/internal/fur/provider"
)
func TitlePrompt(_ provider.InferenceProvider) string {
func TitlePrompt() string {
return `you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 50 characters long
- the title should be a summary of the user's message

View File

@@ -153,9 +153,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
model := a.providerOptions.model(a.providerOptions.modelType)
var thinkingParam anthropic.ThinkingConfigParamUnion
cfg := config.Get()
modelConfig := cfg.Models.Large
if a.providerOptions.modelType == config.SmallModel {
modelConfig = cfg.Models.Small
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if a.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
temperature := anthropic.Float(0)
@@ -399,7 +399,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
if apiErr.StatusCode == 401 {
a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -490,6 +490,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
}
}
func (a *anthropicClient) Model() config.Model {
func (a *anthropicClient) Model() provider.Model {
return a.providerOptions.model(a.providerOptions.modelType)
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -31,14 +32,14 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
}
}
opts.model = func(modelType config.ModelType) config.Model {
model := config.GetModel(modelType)
opts.model = func(modelType config.SelectedModelType) provider.Model {
model := config.Get().GetModelByType(modelType)
// Prefix the model name with region
regionPrefix := region[:2]
modelName := model.ID
model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
return model
return *model
}
model := opts.model(opts.modelType)
@@ -87,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
return b.childProvider.stream(ctx, messages, tools)
}
func (b *bedrockClient) Model() config.Model {
func (b *bedrockClient) Model() provider.Model {
return b.providerOptions.model(b.providerOptions.modelType)
}

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -170,9 +171,9 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
logging.Debug("Prepared messages", "messages", string(jsonData))
}
modelConfig := cfg.Models.Large
if g.providerOptions.modelType == config.SmallModel {
modelConfig = cfg.Models.Small
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
maxTokens := model.DefaultMaxTokens
@@ -268,9 +269,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
logging.Debug("Prepared messages", "messages", string(jsonData))
}
modelConfig := cfg.Models.Large
if g.providerOptions.modelType == config.SmallModel {
modelConfig = cfg.Models.Small
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
maxTokens := model.DefaultMaxTokens
if modelConfig.MaxTokens > 0 {
@@ -424,7 +425,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -462,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
func (g *geminiClient) Model() config.Model {
func (g *geminiClient) Model() provider.Model {
return g.providerOptions.model(g.providerOptions.modelType)
}

View File

@@ -148,15 +148,12 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
model := o.providerOptions.model(o.providerOptions.modelType)
cfg := config.Get()
modelConfig := cfg.Models.Large
if o.providerOptions.modelType == config.SmallModel {
modelConfig = cfg.Models.Small
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if o.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
}
reasoningEffort := model.ReasoningEffort
if modelConfig.ReasoningEffort != "" {
reasoningEffort = modelConfig.ReasoningEffort
}
reasoningEffort := modelConfig.ReasoningEffort
params := openai.ChatCompletionNewParams{
Model: openai.ChatModel(model.ID),
@@ -363,7 +360,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if apiErr.StatusCode == 401 {
o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -420,6 +417,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
}
}
func (a *openaiClient) Model() config.Model {
func (a *openaiClient) Model() provider.Model {
return a.providerOptions.model(a.providerOptions.modelType)
}

View File

@@ -55,15 +55,15 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() config.Model
Model() provider.Model
}
type providerClientOptions struct {
baseURL string
config config.ProviderConfig
apiKey string
modelType config.ModelType
model func(config.ModelType) config.Model
modelType config.SelectedModelType
model func(config.SelectedModelType) provider.Model
disableCache bool
systemMessage string
maxTokens int64
@@ -77,7 +77,7 @@ type ProviderClient interface {
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() config.Model
Model() provider.Model
}
type baseProvider[C ProviderClient] struct {
@@ -106,11 +106,11 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
func (p *baseProvider[C]) Model() config.Model {
func (p *baseProvider[C]) Model() provider.Model {
return p.client.Model()
}
func WithModel(model config.ModelType) ProviderClientOption {
func WithModel(model config.SelectedModelType) ProviderClientOption {
return func(options *providerClientOptions) {
options.modelType = model
}
@@ -135,7 +135,7 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
if err != nil {
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
}
@@ -145,14 +145,14 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
config: cfg,
apiKey: resolvedAPIKey,
extraHeaders: cfg.ExtraHeaders,
model: func(tp config.ModelType) config.Model {
return config.GetModel(tp)
model: func(tp config.SelectedModelType) provider.Model {
return *config.Get().GetModelByType(tp)
},
}
for _, o := range opts {
o(&clientOptions)
}
switch cfg.ProviderType {
switch cfg.Type {
case provider.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
@@ -190,5 +190,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
client: newOpenAIClient(clientOptions),
}, nil
}
return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
}

View File

@@ -317,7 +317,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
p := b.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),
Path: config.Get().WorkingDir(),
ToolName: BashToolName,
Action: "execute",
Description: fmt.Sprintf("Execute command: %s", params.Command),
@@ -337,7 +337,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
defer cancel()
}
stdout, stderr, err := shell.
GetPersistentShell(config.WorkingDirectory()).
GetPersistentShell(config.Get().WorkingDir()).
Exec(ctx, params.Command)
interrupted := shell.IsInterrupt(err)
exitCode := shell.ExitCode(err)

View File

@@ -143,7 +143,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
}
if !filepath.IsAbs(params.FilePath) {
wd := config.WorkingDirectory()
wd := config.Get().WorkingDir()
params.FilePath = filepath.Join(wd, params.FilePath)
}
@@ -207,7 +207,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
content,
filePath,
)
rootDir := config.WorkingDirectory()
rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -320,7 +320,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
filePath,
)
rootDir := config.WorkingDirectory()
rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir
@@ -442,7 +442,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
newContent,
filePath,
)
rootDir := config.WorkingDirectory()
rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir

View File

@@ -133,7 +133,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
p := t.permissions.Request(
permission.CreatePermissionRequest{
SessionID: sessionID,
Path: config.WorkingDirectory(),
Path: config.Get().WorkingDir(),
ToolName: FetchToolName,
Action: "fetch",
Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),

View File

@@ -108,7 +108,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
searchPath := params.Path
if searchPath == "" {
searchPath = config.WorkingDirectory()
searchPath = config.Get().WorkingDir()
}
files, truncated, err := globFiles(params.Pattern, searchPath, 100)

View File

@@ -200,7 +200,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
searchPath := params.Path
if searchPath == "" {
searchPath = config.WorkingDirectory()
searchPath = config.Get().WorkingDir()
}
matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100)

View File

@@ -107,11 +107,11 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
searchPath := params.Path
if searchPath == "" {
searchPath = config.WorkingDirectory()
searchPath = config.Get().WorkingDir()
}
if !filepath.IsAbs(searchPath) {
searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
searchPath = filepath.Join(config.Get().WorkingDir(), searchPath)
}
if _, err := os.Stat(searchPath); os.IsNotExist(err) {

View File

@@ -117,7 +117,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
// Handle relative paths
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(config.WorkingDirectory(), filePath)
filePath = filepath.Join(config.Get().WorkingDir(), filePath)
}
// Check if file exists

View File

@@ -122,7 +122,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
filePath := params.FilePath
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(config.WorkingDirectory(), filePath)
filePath = filepath.Join(config.Get().WorkingDir(), filePath)
}
fileInfo, err := os.Stat(filePath)
@@ -170,7 +170,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
filePath,
)
rootDir := config.WorkingDirectory()
rootDir := config.Get().WorkingDir()
permissionPath := filepath.Dir(filePath)
if strings.HasPrefix(filePath, rootDir) {
permissionPath = rootDir

View File

@@ -376,7 +376,7 @@ func (c *Client) detectServerType() ServerType {
// openKeyConfigFiles opens important configuration files that help initialize the server
func (c *Client) openKeyConfigFiles(ctx context.Context) {
workDir := config.WorkingDirectory()
workDir := config.Get().WorkingDir()
serverType := c.detectServerType()
var filesToOpen []string
@@ -464,7 +464,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
}
// If we have no open TypeScript files, try to find and open one
workDir := config.WorkingDirectory()
workDir := config.Get().WorkingDir()
err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err

View File

@@ -87,7 +87,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
dir := filepath.Dir(opts.Path)
if dir == "." {
dir = config.WorkingDirectory()
dir = config.Get().WorkingDir()
}
permission := PermissionRequest{
ID: uuid.New().String(),

View File

@@ -91,7 +91,7 @@ func (p *header) View() tea.View {
func (h *header) details() string {
t := styles.CurrentTheme()
cwd := fsext.DirTrim(fsext.PrettyPath(config.WorkingDirectory()), 4)
cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4)
parts := []string{
t.S().Muted.Render(cwd),
}
@@ -111,7 +111,8 @@ func (h *header) details() string {
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
}
model := config.GetAgentModel(config.AgentCoder)
agentCfg := config.Get().Agents["coder"]
model := config.Get().GetModelByType(agentCfg.Model)
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
parts = append(parts, formattedPercentage)

View File

@@ -10,7 +10,6 @@ import (
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -296,7 +295,7 @@ func (m *assistantSectionModel) View() tea.View {
duration := finishTime.Sub(m.lastUserMessageTime)
infoMsg := t.S().Subtle.Render(duration.String())
icon := t.S().Subtle.Render(styles.ModelIcon)
model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
model := config.Get().GetModel(m.message.Provider, m.message.Model)
modelFormatted := t.S().Muted.Render(model.Name)
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return tea.NewView(

View File

@@ -297,7 +297,7 @@ func (m *sidebarCmp) filesBlock() string {
}
extraContent := strings.Join(statusParts, " ")
cwd := config.WorkingDirectory() + string(os.PathSeparator)
cwd := config.Get().WorkingDir() + string(os.PathSeparator)
filePath := file.FilePath
filePath = strings.TrimPrefix(filePath, cwd)
filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2)
@@ -474,7 +474,8 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
}
func (s *sidebarCmp) currentModelBlock() string {
model := config.GetAgentModel(config.AgentCoder)
agentCfg := config.Get().Agents["coder"]
model := config.Get().GetModelByType(agentCfg.Model)
t := styles.CurrentTheme()
@@ -507,7 +508,7 @@ func (m *sidebarCmp) SetSession(session session.Session) tea.Cmd {
}
func cwd() string {
cwd := config.WorkingDirectory()
cwd := config.Get().WorkingDir()
t := styles.CurrentTheme()
// Replace home directory with ~, unless we're at the top level of the
// home directory).

View File

@@ -31,8 +31,8 @@ const (
// ModelSelectedMsg is sent when a model is selected
type ModelSelectedMsg struct {
Model config.PreferredModel
ModelType config.ModelType
Model config.SelectedModel
ModelType config.SelectedModelType
}
// CloseModelDialogMsg is sent when a model is selected
@@ -115,19 +115,19 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
items := m.modelList.Items()
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
var modelType config.ModelType
var modelType config.SelectedModelType
if m.modelType == LargeModelType {
modelType = config.LargeModel
modelType = config.SelectedModelTypeLarge
} else {
modelType = config.SmallModel
modelType = config.SelectedModelTypeSmall
}
return m, tea.Sequence(
util.CmdHandler(dialogs.CloseDialogMsg{}),
util.CmdHandler(ModelSelectedMsg{
Model: config.PreferredModel{
ModelID: selectedItem.Model.ID,
Provider: selectedItem.Provider.ID,
Model: config.SelectedModel{
Model: selectedItem.Model.ID,
Provider: string(selectedItem.Provider.ID),
},
ModelType: modelType,
}),
@@ -218,35 +218,39 @@ func (m *modelDialogCmp) modelTypeRadio() string {
func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
m.modelType = modelType
providers := config.Providers()
providers, err := config.Providers()
if err != nil {
return util.ReportError(err)
}
modelItems := []util.Model{}
selectIndex := 0
cfg := config.Get()
var currentModel config.PreferredModel
var currentModel config.SelectedModel
if m.modelType == LargeModelType {
currentModel = cfg.Models.Large
currentModel = cfg.Models[config.SelectedModelTypeLarge]
} else {
currentModel = cfg.Models.Small
currentModel = cfg.Models[config.SelectedModelTypeSmall]
}
// Create a map to track which providers we've already added
addedProviders := make(map[provider.InferenceProvider]bool)
addedProviders := make(map[string]bool)
// First, add any configured providers that are not in the known providers list
// These should appear at the top of the list
knownProviders := provider.KnownProviders()
for providerID, providerConfig := range cfg.Providers {
if providerConfig.Disabled {
if providerConfig.Disable {
continue
}
// Check if this provider is not in the known providers list
if !slices.Contains(knownProviders, providerID) {
if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
// Convert config provider to provider.Provider format
configProvider := provider.Provider{
Name: string(providerID), // Use provider ID as name for unknown providers
ID: providerID,
ID: provider.InferenceProvider(providerID),
Models: make([]provider.Model, len(providerConfig.Models)),
}
@@ -263,7 +267,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
DefaultMaxTokens: model.DefaultMaxTokens,
CanReason: model.CanReason,
HasReasoningEffort: model.HasReasoningEffort,
DefaultReasoningEffort: model.ReasoningEffort,
DefaultReasoningEffort: model.DefaultReasoningEffort,
SupportsImages: model.SupportsImages,
}
}
@@ -279,7 +283,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
Provider: configProvider,
Model: model,
}))
if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
}
}
@@ -290,12 +294,12 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
// Then add the known providers from the predefined list
for _, provider := range providers {
// Skip if we already added this provider as an unknown provider
if addedProviders[provider.ID] {
if addedProviders[string(provider.ID)] {
continue
}
// Check if this provider is configured and not disabled
if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
continue
}
@@ -309,7 +313,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
Provider: provider,
Model: model,
}))
if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
}
}

View File

@@ -170,7 +170,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
util.CmdHandler(ChatFocusedMsg{Focused: false}),
)
case key.Matches(msg, p.keyMap.AddAttachment):
model := config.GetAgentModel(config.AgentCoder)
agentCfg := config.Get().Agents["coder"]
model := config.Get().GetModelByType(agentCfg.Model)
if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {

View File

@@ -177,14 +177,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Update the agent with the new model/provider configuration
if err := a.app.UpdateAgentModel(); err != nil {
logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err))
return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err))
return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err))
}
modelTypeName := "large"
if msg.ModelType == config.SmallModel {
if msg.ModelType == config.SelectedModelTypeSmall {
modelTypeName = "small"
}
return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID))
return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.Model))
// File Picker
case chat.OpenFilePickerMsg:

View File

@@ -1,224 +0,0 @@
package config
import (
"slices"
"strings"
"github.com/charmbracelet/crush/internal/fur/provider"
)
// Application-wide defaults.
const (
	appName              = "crush"   // application name; also used to build data/cache paths
	defaultDataDirectory = ".crush"  // default per-project data directory (relative to the cwd)
	defaultLogLevel      = "info"
)
// defaultContextPaths are the well-known context/instruction files looked up
// by default. Multiple casings of crush.md are listed because file lookups
// are case-sensitive on some filesystems.
var defaultContextPaths = []string{
	".github/copilot-instructions.md",
	".cursorrules",
	".cursor/rules/",
	"CLAUDE.md",
	"CLAUDE.local.md",
	"GEMINI.md",
	"gemini.md",
	"crush.md",
	"crush.local.md",
	"Crush.md",
	"Crush.local.md",
	"CRUSH.md",
	"CRUSH.local.md",
}
// SelectedModelType identifies which configured model slot to use.
type SelectedModelType string

const (
	// SelectedModelTypeLarge is the primary model slot.
	SelectedModelTypeLarge SelectedModelType = "large"
	// SelectedModelTypeSmall is the secondary, cheaper model slot.
	SelectedModelTypeSmall SelectedModelType = "small"
)
// SelectedModel pins one concrete model choice (provider + model id) along
// with optional per-selection overrides of the model defaults.
type SelectedModel struct {
	// The model id as used by the provider API.
	// Required.
	Model string `json:"model"`
	// The model provider, same as the key/id used in the providers config.
	// Required.
	Provider string `json:"provider"`

	// Only used by models that use the openai provider and need this set.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`

	// Overrides the default model configuration.
	MaxTokens int64 `json:"max_tokens,omitempty"`

	// Used by anthropic models that can reason to indicate if the model should think.
	Think bool `json:"think,omitempty"`
}
// ProviderConfig describes a single LLM provider entry in the config file.
type ProviderConfig struct {
	// The provider's id.
	ID string `json:"id,omitempty"`
	// The provider's API endpoint.
	BaseURL string `json:"base_url,omitempty"`
	// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
	Type provider.Type `json:"type,omitempty"`
	// The provider's API key.
	APIKey string `json:"api_key,omitempty"`
	// Marks the provider as disabled.
	Disable bool `json:"disable,omitempty"`

	// Extra headers to send with each request to the provider.
	// NOTE: this field previously had no json tag and therefore
	// (un)marshaled under the key "ExtraHeaders", inconsistent with the
	// snake_case keys used by every other field.
	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`

	// Used to pass extra parameters to the provider.
	// Never serialized.
	ExtraParams map[string]string `json:"-"`

	// The provider models
	Models []provider.Model `json:"models,omitempty"`
}
type MCPType string
const (
MCPStdio MCPType = "stdio"
MCPSse MCPType = "sse"
MCPHttp MCPType = "http"
)
type MCPConfig struct {
Command string `json:"command,omitempty" `
Env []string `json:"env,omitempty"`
Args []string `json:"args,omitempty"`
Type MCPType `json:"type"`
URL string `json:"url,omitempty"`
// TODO: maybe make it possible to get the value from the env
Headers map[string]string `json:"headers,omitempty"`
}
type LSPConfig struct {
Disabled bool `json:"enabled,omitempty"`
Command string `json:"command"`
Args []string `json:"args,omitempty"`
Options any `json:"options,omitempty"`
}
// TUIOptions holds TUI-specific settings.
type TUIOptions struct {
	// CompactMode selects the compact TUI layout.
	CompactMode bool `json:"compact_mode,omitempty"`
	// Here we can add themes later or any TUI related options
}
// Options are miscellaneous application-level settings.
type Options struct {
	// ContextPaths lists context/instruction files to load.
	ContextPaths []string `json:"context_paths,omitempty"`
	// TUI holds TUI-specific options; nil means defaults.
	TUI *TUIOptions `json:"tui,omitempty"`
	// Debug enables general debug behavior; DebugLSP enables LSP-specific debugging.
	Debug    bool `json:"debug,omitempty"`
	DebugLSP bool `json:"debug_lsp,omitempty"`
	// DisableAutoSummarize turns off automatic conversation summarization.
	DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
	// DataDirectory is where crush stores its data. Relative to the cwd.
	DataDirectory string `json:"data_directory,omitempty"`
}
// MCPs maps an MCP server name to its configuration.
type MCPs map[string]MCPConfig

// MCP pairs a server name with its configuration, for ordered listings.
type MCP struct {
	Name string    `json:"name"`
	MCP  MCPConfig `json:"mcp"`
}

// Sorted returns the configured MCP servers as a slice ordered by name.
func (m MCPs) Sorted() []MCP {
	names := make([]string, 0, len(m))
	for name := range m {
		names = append(names, name)
	}
	slices.Sort(names)

	out := make([]MCP, 0, len(names))
	for _, name := range names {
		out = append(out, MCP{Name: name, MCP: m[name]})
	}
	return out
}
// LSPs maps a language-server name to its configuration.
type LSPs map[string]LSPConfig

// LSP pairs a language-server name with its configuration, for ordered listings.
type LSP struct {
	Name string    `json:"name"`
	LSP  LSPConfig `json:"lsp"`
}

// Sorted returns the configured language servers as a slice ordered by name.
func (l LSPs) Sorted() []LSP {
	names := make([]string, 0, len(l))
	for name := range l {
		names = append(names, name)
	}
	slices.Sort(names)

	out := make([]LSP, 0, len(names))
	for _, name := range names {
		out = append(out, LSP{Name: name, LSP: l[name]})
	}
	return out
}
// Config holds the configuration for crush.
type Config struct {
	// Models maps a model slot to the selected model.
	// We currently only support large/small as values here.
	Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
	// The providers that are configured, keyed by provider id.
	Providers map[string]ProviderConfig `json:"providers,omitempty"`

	MCP MCPs `json:"mcp,omitempty"`
	LSP LSPs `json:"lsp,omitempty"`

	Options *Options `json:"options,omitempty"`

	// Internal. Unexported fields are never (de)serialized by
	// encoding/json, so no `json:"-"` tag is needed (go vet flags json
	// tags on unexported fields).
	workingDir string
}
// WorkingDir returns the working directory this config was loaded for.
func (c *Config) WorkingDir() string {
	return c.workingDir
}
// EnabledProviders returns every configured provider whose Disable flag
// is not set. The result is never nil (it may be empty).
func (c *Config) EnabledProviders() []ProviderConfig {
	out := make([]ProviderConfig, 0, len(c.Providers))
	for _, cfg := range c.Providers {
		if cfg.Disable {
			continue
		}
		out = append(out, cfg)
	}
	return out
}
// IsConfigured reports whether at least one enabled provider exists.
func (c *Config) IsConfigured() bool {
	for _, cfg := range c.Providers {
		if !cfg.Disable {
			return true
		}
	}
	return false
}
// GetModel looks up a model by provider ID and model ID, returning a
// pointer to a copy of the matching model, or nil when either the
// provider or the model is unknown.
func (c *Config) GetModel(provider, model string) *provider.Model {
	cfg, ok := c.Providers[provider]
	if !ok {
		return nil
	}
	for i := range cfg.Models {
		if cfg.Models[i].ID == model {
			found := cfg.Models[i]
			return &found
		}
	}
	return nil
}
// LargeModel resolves the currently selected "large" model, or nil when
// no large model has been selected.
func (c *Config) LargeModel() *provider.Model {
	if sel, ok := c.Models[SelectedModelTypeLarge]; ok {
		return c.GetModel(sel.Provider, sel.Model)
	}
	return nil
}
// SmallModel resolves the currently selected "small" model, or nil when
// no small model has been selected.
func (c *Config) SmallModel() *provider.Model {
	if sel, ok := c.Models[SelectedModelTypeSmall]; ok {
		return c.GetModel(sel.Provider, sel.Model)
	}
	return nil
}

View File

@@ -1,93 +0,0 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/charmbracelet/crush/internal/fur/provider"
)
// ProviderClient fetches the catalog of available providers,
// presumably from a remote service — see the concrete implementation.
type ProviderClient interface {
	GetProviders() ([]provider.Provider, error)
}
var (
	// providerOnce guards the one-time provider load in LoadProviders.
	providerOnce sync.Once
	// providerList memoizes the result of the first load.
	providerList []provider.Provider
)
// providerCacheFileData returns the path of the file used to cache
// provider data. Resolution order:
//   - $XDG_DATA_HOME/<app>/providers.json when XDG_DATA_HOME is set
//   - %LOCALAPPDATA%\<app>\providers.json on Windows
//   - $HOME/.local/share/<app>/providers.json otherwise
//
// The XDG and Windows branches previously returned only the directory,
// omitting the providers.json filename; every branch now returns a file
// path so cache reads/writes do not target a directory.
func providerCacheFileData() string {
	if xdgDataHome := os.Getenv("XDG_DATA_HOME"); xdgDataHome != "" {
		return filepath.Join(xdgDataHome, appName, "providers.json")
	}

	// On Windows the data directory lives under %LOCALAPPDATA%, falling
	// back to %USERPROFILE%\AppData\Local when unset.
	if runtime.GOOS == "windows" {
		localAppData := os.Getenv("LOCALAPPDATA")
		if localAppData == "" {
			localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
		}
		return filepath.Join(localAppData, appName, "providers.json")
	}

	return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
}
// saveProvidersInCache writes the provider catalog to path as indented
// JSON, creating the parent directory if needed.
func saveProvidersInCache(path string, providers []provider.Provider) error {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	payload, err := json.MarshalIndent(providers, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(path, payload, 0o644)
}
// loadProvidersFromCache reads the provider catalog previously written
// to path by saveProvidersInCache.
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}
	var cached []provider.Provider
	decodeErr := json.Unmarshal(raw, &cached)
	return cached, decodeErr
}
// loadProviders fetches the provider catalog from client. On success the
// fresh catalog is persisted to the cache at path; on failure the last
// cached catalog is returned instead (erroring only when no cache exists).
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
	providers, err := client.GetProviders()
	if err != nil {
		// Remote fetch failed: fall back to the on-disk cache.
		cached, cacheErr := loadProvidersFromCache(path)
		if cacheErr != nil {
			return nil, cacheErr
		}
		return cached, nil
	}
	// Cache the providers we just fetched. The previous code wrote the
	// package-level providerList here, which is still nil at this point,
	// so the cache always contained "null".
	if err := saveProvidersInCache(path, providers); err != nil {
		return nil, err
	}
	return providers, nil
}
// LoadProviders returns the provider catalog, loading it at most once per
// process (memoized in providerList via providerOnce).
//
// NOTE(review): if the first load fails, providerOnce is still consumed,
// so every subsequent call returns (nil, nil) without retrying — confirm
// this is the intended behavior.
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
	var err error
	providerOnce.Do(func() {
		providerList, err = loadProviders(providerCacheFileData(), client)
	})
	if err != nil {
		return nil, err
	}
	return providerList, nil
}

View File

@@ -1,73 +0,0 @@
package config
import (
"encoding/json"
"errors"
"os"
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
)
// mockProviderClient is a test double for ProviderClient.
type mockProviderClient struct {
	// shouldFail makes GetProviders return an error instead of a catalog.
	shouldFail bool
}
// GetProviders returns one stub provider, or an error when the mock is
// configured to fail.
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
	if m.shouldFail {
		return nil, errors.New("failed to load providers")
	}
	stub := provider.Provider{Name: "Mock"}
	return []provider.Provider{stub}, nil
}
func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(tmpPath, client)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
// check if file got saved
fileInfo, err := os.Stat(tmpPath)
assert.NoError(t, err)
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
}
func TestProvider_loadProvidersWithIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
// store providers to a temporary file
oldProviders := []provider.Provider{
{
Name: "OldProvider",
},
}
data, err := json.Marshal(oldProviders)
if err != nil {
t.Fatalf("Failed to marshal old providers: %v", err)
}
err = os.WriteFile(tmpPath, data, 0o644)
if err != nil {
t.Fatalf("Failed to write old providers to file: %v", err)
}
providers, err := loadProviders(tmpPath, client)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
}
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(tmpPath, client)
assert.Error(t, err)
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}