mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
wip: integrate to existing app
This commit is contained in:
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/pkg/config"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/log/v2"
|
||||
"github.com/nxadm/tail"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/invopop/jsonschema"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create a new reflector
|
||||
r := &jsonschema.Reflector{
|
||||
// Use anonymous schemas to avoid ID conflicts
|
||||
Anonymous: true,
|
||||
// Expand the root struct instead of referencing it
|
||||
ExpandedStruct: true,
|
||||
AllowAdditionalProperties: true,
|
||||
}
|
||||
|
||||
// Generate schema for the main Config struct
|
||||
schema := r.Reflect(&config.Config{})
|
||||
|
||||
// Enhance the schema with additional information
|
||||
enhanceSchema(schema)
|
||||
|
||||
// Set the schema metadata
|
||||
schema.Version = "https://json-schema.org/draft/2020-12/schema"
|
||||
schema.Title = "Crush Configuration"
|
||||
schema.Description = "Configuration schema for the Crush application"
|
||||
|
||||
// Pretty print the schema
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(schema); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// enhanceSchema adds additional enhancements to the generated schema
|
||||
func enhanceSchema(schema *jsonschema.Schema) {
|
||||
// Add provider enums
|
||||
addProviderEnums(schema)
|
||||
|
||||
// Add model enums
|
||||
addModelEnums(schema)
|
||||
|
||||
// Add tool enums
|
||||
addToolEnums(schema)
|
||||
|
||||
// Add default context paths
|
||||
addDefaultContextPaths(schema)
|
||||
}
|
||||
|
||||
// addProviderEnums adds provider enums to the schema
|
||||
func addProviderEnums(schema *jsonschema.Schema) {
|
||||
providers := config.Providers()
|
||||
var providerIDs []any
|
||||
for _, p := range providers {
|
||||
providerIDs = append(providerIDs, string(p.ID))
|
||||
}
|
||||
|
||||
// Add to PreferredModel provider field
|
||||
if schema.Definitions != nil {
|
||||
if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
|
||||
if providerProp, exists := preferredModelDef.Properties.Get("provider"); exists {
|
||||
providerProp.Enum = providerIDs
|
||||
}
|
||||
}
|
||||
|
||||
// Add to ProviderConfig ID field
|
||||
if providerConfigDef, exists := schema.Definitions["ProviderConfig"]; exists {
|
||||
if idProp, exists := providerConfigDef.Properties.Get("id"); exists {
|
||||
idProp.Enum = providerIDs
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addModelEnums adds model enums to the schema
|
||||
func addModelEnums(schema *jsonschema.Schema) {
|
||||
providers := config.Providers()
|
||||
var modelIDs []any
|
||||
for _, p := range providers {
|
||||
for _, m := range p.Models {
|
||||
modelIDs = append(modelIDs, m.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Add to PreferredModel model_id field
|
||||
if schema.Definitions != nil {
|
||||
if preferredModelDef, exists := schema.Definitions["PreferredModel"]; exists {
|
||||
if modelIDProp, exists := preferredModelDef.Properties.Get("model_id"); exists {
|
||||
modelIDProp.Enum = modelIDs
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addToolEnums adds tool enums to the schema
|
||||
func addToolEnums(schema *jsonschema.Schema) {
|
||||
tools := []any{
|
||||
"bash", "edit", "fetch", "glob", "grep", "ls", "sourcegraph", "view", "write", "agent",
|
||||
}
|
||||
|
||||
if schema.Definitions != nil {
|
||||
if agentDef, exists := schema.Definitions["Agent"]; exists {
|
||||
if allowedToolsProp, exists := agentDef.Properties.Get("allowed_tools"); exists {
|
||||
if allowedToolsProp.Items != nil {
|
||||
allowedToolsProp.Items.Enum = tools
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addDefaultContextPaths adds default context paths to the schema
|
||||
func addDefaultContextPaths(schema *jsonschema.Schema) {
|
||||
defaultContextPaths := []any{
|
||||
".github/copilot-instructions.md",
|
||||
".cursorrules",
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"GEMINI.md",
|
||||
"gemini.md",
|
||||
"crush.md",
|
||||
"crush.local.md",
|
||||
"Crush.md",
|
||||
"Crush.local.md",
|
||||
"CRUSH.md",
|
||||
"CRUSH.local.md",
|
||||
}
|
||||
|
||||
if schema.Definitions != nil {
|
||||
if optionsDef, exists := schema.Definitions["Options"]; exists {
|
||||
if contextPathsProp, exists := optionsDef.Properties.Get("context_paths"); exists {
|
||||
contextPathsProp.Default = defaultContextPaths
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also add to root properties if they exist
|
||||
if schema.Properties != nil {
|
||||
if optionsProp, exists := schema.Properties.Get("options"); exists {
|
||||
if optionsProp.Properties != nil {
|
||||
if contextPathsProp, exists := optionsProp.Properties.Get("context_paths"); exists {
|
||||
contextPathsProp.Default = defaultContextPaths
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,700 +0,0 @@
|
||||
{
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"$defs": {
|
||||
"Agent": {
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"coder",
|
||||
"task",
|
||||
"coder",
|
||||
"task"
|
||||
],
|
||||
"title": "Agent ID",
|
||||
"description": "Unique identifier for the agent"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"title": "Name",
|
||||
"description": "Display name of the agent"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"title": "Description",
|
||||
"description": "Description of what the agent does"
|
||||
},
|
||||
"disabled": {
|
||||
"type": "boolean",
|
||||
"title": "Disabled",
|
||||
"description": "Whether this agent is disabled",
|
||||
"default": false
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"large",
|
||||
"small",
|
||||
"large",
|
||||
"small"
|
||||
],
|
||||
"title": "Model Type",
|
||||
"description": "Type of model to use (large or small)"
|
||||
},
|
||||
"allowed_tools": {
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"bash",
|
||||
"edit",
|
||||
"fetch",
|
||||
"glob",
|
||||
"grep",
|
||||
"ls",
|
||||
"sourcegraph",
|
||||
"view",
|
||||
"write",
|
||||
"agent"
|
||||
]
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Allowed Tools",
|
||||
"description": "List of tools this agent is allowed to use (if nil all tools are allowed)"
|
||||
},
|
||||
"allowed_mcp": {
|
||||
"additionalProperties": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Allowed MCP",
|
||||
"description": "Map of MCP servers this agent can use and their allowed tools"
|
||||
},
|
||||
"allowed_lsp": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Allowed LSP",
|
||||
"description": "List of LSP servers this agent can use (if nil all LSPs are allowed)"
|
||||
},
|
||||
"context_paths": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Context Paths",
|
||||
"description": "Custom context paths for this agent (additive to global context paths)"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model"
|
||||
]
|
||||
},
|
||||
"LSPConfig": {
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"title": "Enabled",
|
||||
"description": "Whether this LSP server is enabled",
|
||||
"default": true
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"title": "Command",
|
||||
"description": "Command to execute for the LSP server"
|
||||
},
|
||||
"args": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Arguments",
|
||||
"description": "Command line arguments for the LSP server"
|
||||
},
|
||||
"options": {
|
||||
"title": "Options",
|
||||
"description": "LSP server specific options"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"command"
|
||||
]
|
||||
},
|
||||
"MCP": {
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"title": "Command",
|
||||
"description": "Command to execute for stdio MCP servers"
|
||||
},
|
||||
"env": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Environment",
|
||||
"description": "Environment variables for the MCP server"
|
||||
},
|
||||
"args": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Arguments",
|
||||
"description": "Command line arguments for the MCP server"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"stdio",
|
||||
"sse",
|
||||
"stdio",
|
||||
"sse",
|
||||
"http"
|
||||
],
|
||||
"title": "Type",
|
||||
"description": "Type of MCP connection",
|
||||
"default": "stdio"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"title": "URL",
|
||||
"description": "URL for SSE MCP servers"
|
||||
},
|
||||
"headers": {
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Headers",
|
||||
"description": "HTTP headers for SSE MCP servers"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"type"
|
||||
]
|
||||
},
|
||||
"Model": {
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"title": "Model ID",
|
||||
"description": "Unique identifier for the model"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"title": "Model Name",
|
||||
"description": "Display name of the model"
|
||||
},
|
||||
"cost_per_1m_in": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"title": "Input Cost",
|
||||
"description": "Cost per 1 million input tokens"
|
||||
},
|
||||
"cost_per_1m_out": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"title": "Output Cost",
|
||||
"description": "Cost per 1 million output tokens"
|
||||
},
|
||||
"cost_per_1m_in_cached": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"title": "Cached Input Cost",
|
||||
"description": "Cost per 1 million cached input tokens"
|
||||
},
|
||||
"cost_per_1m_out_cached": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"title": "Cached Output Cost",
|
||||
"description": "Cost per 1 million cached output tokens"
|
||||
},
|
||||
"context_window": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"title": "Context Window",
|
||||
"description": "Maximum context window size in tokens"
|
||||
},
|
||||
"default_max_tokens": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"title": "Default Max Tokens",
|
||||
"description": "Default maximum tokens for responses"
|
||||
},
|
||||
"can_reason": {
|
||||
"type": "boolean",
|
||||
"title": "Can Reason",
|
||||
"description": "Whether the model supports reasoning capabilities"
|
||||
},
|
||||
"reasoning_effort": {
|
||||
"type": "string",
|
||||
"title": "Reasoning Effort",
|
||||
"description": "Default reasoning effort level for reasoning models"
|
||||
},
|
||||
"has_reasoning_effort": {
|
||||
"type": "boolean",
|
||||
"title": "Has Reasoning Effort",
|
||||
"description": "Whether the model supports reasoning effort configuration"
|
||||
},
|
||||
"supports_attachments": {
|
||||
"type": "boolean",
|
||||
"title": "Supports Images",
|
||||
"description": "Whether the model supports image attachments"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"name",
|
||||
"context_window",
|
||||
"default_max_tokens"
|
||||
]
|
||||
},
|
||||
"Options": {
|
||||
"properties": {
|
||||
"context_paths": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Context Paths",
|
||||
"description": "List of paths to search for context files",
|
||||
"default": [
|
||||
".github/copilot-instructions.md",
|
||||
".cursorrules",
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"GEMINI.md",
|
||||
"gemini.md",
|
||||
"crush.md",
|
||||
"crush.local.md",
|
||||
"Crush.md",
|
||||
"Crush.local.md",
|
||||
"CRUSH.md",
|
||||
"CRUSH.local.md"
|
||||
]
|
||||
},
|
||||
"tui": {
|
||||
"$ref": "#/$defs/TUIOptions",
|
||||
"title": "TUI Options",
|
||||
"description": "Terminal UI configuration options"
|
||||
},
|
||||
"debug": {
|
||||
"type": "boolean",
|
||||
"title": "Debug",
|
||||
"description": "Enable debug logging",
|
||||
"default": false
|
||||
},
|
||||
"debug_lsp": {
|
||||
"type": "boolean",
|
||||
"title": "Debug LSP",
|
||||
"description": "Enable LSP debug logging",
|
||||
"default": false
|
||||
},
|
||||
"disable_auto_summarize": {
|
||||
"type": "boolean",
|
||||
"title": "Disable Auto Summarize",
|
||||
"description": "Disable automatic conversation summarization",
|
||||
"default": false
|
||||
},
|
||||
"data_directory": {
|
||||
"type": "string",
|
||||
"title": "Data Directory",
|
||||
"description": "Directory for storing application data",
|
||||
"default": ".crush"
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"PreferredModel": {
|
||||
"properties": {
|
||||
"model_id": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"claude-opus-4-20250514",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"codex-mini-latest",
|
||||
"o4-mini",
|
||||
"o3",
|
||||
"o3-pro",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"gpt-4.5-preview",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"codex-mini-latest",
|
||||
"o4-mini",
|
||||
"o3",
|
||||
"o3-pro",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"gpt-4.5-preview",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"grok-3-mini",
|
||||
"grok-3",
|
||||
"mistralai/mistral-small-3.2-24b-instruct:free",
|
||||
"mistralai/mistral-small-3.2-24b-instruct",
|
||||
"minimax/minimax-m1:extended",
|
||||
"minimax/minimax-m1",
|
||||
"google/gemini-2.5-flash-lite-preview-06-17",
|
||||
"google/gemini-2.5-flash",
|
||||
"google/gemini-2.5-pro",
|
||||
"openai/o3-pro",
|
||||
"x-ai/grok-3-mini",
|
||||
"x-ai/grok-3",
|
||||
"mistralai/magistral-small-2506",
|
||||
"mistralai/magistral-medium-2506",
|
||||
"mistralai/magistral-medium-2506:thinking",
|
||||
"google/gemini-2.5-pro-preview",
|
||||
"deepseek/deepseek-r1-0528",
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"mistralai/devstral-small:free",
|
||||
"mistralai/devstral-small",
|
||||
"google/gemini-2.5-flash-preview-05-20",
|
||||
"google/gemini-2.5-flash-preview-05-20:thinking",
|
||||
"openai/codex-mini",
|
||||
"mistralai/mistral-medium-3",
|
||||
"google/gemini-2.5-pro-preview-05-06",
|
||||
"arcee-ai/caller-large",
|
||||
"arcee-ai/virtuoso-large",
|
||||
"arcee-ai/virtuoso-medium-v2",
|
||||
"qwen/qwen3-30b-a3b",
|
||||
"qwen/qwen3-14b",
|
||||
"qwen/qwen3-32b",
|
||||
"qwen/qwen3-235b-a22b",
|
||||
"google/gemini-2.5-flash-preview",
|
||||
"google/gemini-2.5-flash-preview:thinking",
|
||||
"openai/o4-mini-high",
|
||||
"openai/o3",
|
||||
"openai/o4-mini",
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-4.1-mini",
|
||||
"openai/gpt-4.1-nano",
|
||||
"x-ai/grok-3-mini-beta",
|
||||
"x-ai/grok-3-beta",
|
||||
"meta-llama/llama-4-maverick",
|
||||
"meta-llama/llama-4-scout",
|
||||
"all-hands/openhands-lm-32b-v0.1",
|
||||
"google/gemini-2.5-pro-exp-03-25",
|
||||
"deepseek/deepseek-chat-v3-0324:free",
|
||||
"deepseek/deepseek-chat-v3-0324",
|
||||
"mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
"mistralai/mistral-small-3.1-24b-instruct",
|
||||
"ai21/jamba-1.6-large",
|
||||
"ai21/jamba-1.6-mini",
|
||||
"openai/gpt-4.5-preview",
|
||||
"google/gemini-2.0-flash-lite-001",
|
||||
"anthropic/claude-3.7-sonnet",
|
||||
"anthropic/claude-3.7-sonnet:beta",
|
||||
"anthropic/claude-3.7-sonnet:thinking",
|
||||
"mistralai/mistral-saba",
|
||||
"openai/o3-mini-high",
|
||||
"google/gemini-2.0-flash-001",
|
||||
"qwen/qwen-turbo",
|
||||
"qwen/qwen-plus",
|
||||
"qwen/qwen-max",
|
||||
"openai/o3-mini",
|
||||
"mistralai/mistral-small-24b-instruct-2501",
|
||||
"deepseek/deepseek-r1-distill-llama-70b",
|
||||
"deepseek/deepseek-r1",
|
||||
"mistralai/codestral-2501",
|
||||
"deepseek/deepseek-chat",
|
||||
"openai/o1",
|
||||
"x-ai/grok-2-1212",
|
||||
"meta-llama/llama-3.3-70b-instruct",
|
||||
"amazon/nova-lite-v1",
|
||||
"amazon/nova-micro-v1",
|
||||
"amazon/nova-pro-v1",
|
||||
"openai/gpt-4o-2024-11-20",
|
||||
"mistralai/mistral-large-2411",
|
||||
"mistralai/mistral-large-2407",
|
||||
"mistralai/pixtral-large-2411",
|
||||
"thedrummer/unslopnemo-12b",
|
||||
"anthropic/claude-3.5-haiku:beta",
|
||||
"anthropic/claude-3.5-haiku",
|
||||
"anthropic/claude-3.5-haiku-20241022:beta",
|
||||
"anthropic/claude-3.5-haiku-20241022",
|
||||
"anthropic/claude-3.5-sonnet:beta",
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
"x-ai/grok-beta",
|
||||
"mistralai/ministral-8b",
|
||||
"mistralai/ministral-3b",
|
||||
"nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
"google/gemini-flash-1.5-8b",
|
||||
"meta-llama/llama-3.2-11b-vision-instruct",
|
||||
"meta-llama/llama-3.2-3b-instruct",
|
||||
"qwen/qwen-2.5-72b-instruct",
|
||||
"mistralai/pixtral-12b",
|
||||
"cohere/command-r-plus-08-2024",
|
||||
"cohere/command-r-08-2024",
|
||||
"microsoft/phi-3.5-mini-128k-instruct",
|
||||
"nousresearch/hermes-3-llama-3.1-70b",
|
||||
"openai/gpt-4o-2024-08-06",
|
||||
"meta-llama/llama-3.1-405b-instruct",
|
||||
"meta-llama/llama-3.1-70b-instruct",
|
||||
"meta-llama/llama-3.1-8b-instruct",
|
||||
"mistralai/mistral-nemo",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/gpt-4o-mini-2024-07-18",
|
||||
"anthropic/claude-3.5-sonnet-20240620:beta",
|
||||
"anthropic/claude-3.5-sonnet-20240620",
|
||||
"mistralai/mistral-7b-instruct-v0.3",
|
||||
"mistralai/mistral-7b-instruct:free",
|
||||
"mistralai/mistral-7b-instruct",
|
||||
"microsoft/phi-3-mini-128k-instruct",
|
||||
"microsoft/phi-3-medium-128k-instruct",
|
||||
"google/gemini-flash-1.5",
|
||||
"openai/gpt-4o-2024-05-13",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o:extended",
|
||||
"meta-llama/llama-3-8b-instruct",
|
||||
"meta-llama/llama-3-70b-instruct",
|
||||
"mistralai/mixtral-8x22b-instruct",
|
||||
"openai/gpt-4-turbo",
|
||||
"google/gemini-pro-1.5",
|
||||
"cohere/command-r-plus",
|
||||
"cohere/command-r-plus-04-2024",
|
||||
"cohere/command-r",
|
||||
"anthropic/claude-3-haiku:beta",
|
||||
"anthropic/claude-3-haiku",
|
||||
"anthropic/claude-3-opus:beta",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet:beta",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"cohere/command-r-03-2024",
|
||||
"mistralai/mistral-large",
|
||||
"openai/gpt-3.5-turbo-0613",
|
||||
"openai/gpt-4-turbo-preview",
|
||||
"mistralai/mistral-small",
|
||||
"mistralai/mistral-tiny",
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"openai/gpt-4-1106-preview",
|
||||
"mistralai/mistral-7b-instruct-v0.1",
|
||||
"openai/gpt-3.5-turbo-16k",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-4-0314"
|
||||
],
|
||||
"title": "Model ID",
|
||||
"description": "ID of the preferred model"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"gemini",
|
||||
"azure",
|
||||
"bedrock",
|
||||
"vertex",
|
||||
"xai",
|
||||
"openrouter"
|
||||
],
|
||||
"title": "Provider",
|
||||
"description": "Provider for the preferred model"
|
||||
},
|
||||
"reasoning_effort": {
|
||||
"type": "string",
|
||||
"title": "Reasoning Effort",
|
||||
"description": "Override reasoning effort for this model"
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"title": "Max Tokens",
|
||||
"description": "Override max tokens for this model"
|
||||
},
|
||||
"think": {
|
||||
"type": "boolean",
|
||||
"title": "Think",
|
||||
"description": "Enable thinking for reasoning models",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model_id",
|
||||
"provider"
|
||||
]
|
||||
},
|
||||
"PreferredModels": {
|
||||
"properties": {
|
||||
"large": {
|
||||
"$ref": "#/$defs/PreferredModel",
|
||||
"title": "Large Model",
|
||||
"description": "Preferred model configuration for large model type"
|
||||
},
|
||||
"small": {
|
||||
"$ref": "#/$defs/PreferredModel",
|
||||
"title": "Small Model",
|
||||
"description": "Preferred model configuration for small model type"
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"ProviderConfig": {
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"gemini",
|
||||
"azure",
|
||||
"bedrock",
|
||||
"vertex",
|
||||
"xai",
|
||||
"openrouter"
|
||||
],
|
||||
"title": "Provider ID",
|
||||
"description": "Unique identifier for the provider"
|
||||
},
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"title": "Base URL",
|
||||
"description": "Base URL for the provider API (required for custom providers)"
|
||||
},
|
||||
"provider_type": {
|
||||
"type": "string",
|
||||
"title": "Provider Type",
|
||||
"description": "Type of the provider (openai"
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"title": "API Key",
|
||||
"description": "API key for authenticating with the provider"
|
||||
},
|
||||
"disabled": {
|
||||
"type": "boolean",
|
||||
"title": "Disabled",
|
||||
"description": "Whether this provider is disabled",
|
||||
"default": false
|
||||
},
|
||||
"extra_headers": {
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Extra Headers",
|
||||
"description": "Additional HTTP headers to send with requests"
|
||||
},
|
||||
"extra_params": {
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Extra Parameters",
|
||||
"description": "Additional provider-specific parameters"
|
||||
},
|
||||
"default_large_model": {
|
||||
"type": "string",
|
||||
"title": "Default Large Model",
|
||||
"description": "Default model ID for large model type"
|
||||
},
|
||||
"default_small_model": {
|
||||
"type": "string",
|
||||
"title": "Default Small Model",
|
||||
"description": "Default model ID for small model type"
|
||||
},
|
||||
"models": {
|
||||
"items": {
|
||||
"$ref": "#/$defs/Model"
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Models",
|
||||
"description": "List of available models for this provider"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"provider_type"
|
||||
]
|
||||
},
|
||||
"TUIOptions": {
|
||||
"properties": {
|
||||
"compact_mode": {
|
||||
"type": "boolean",
|
||||
"title": "Compact Mode",
|
||||
"description": "Enable compact mode for the TUI",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"compact_mode"
|
||||
]
|
||||
}
|
||||
},
|
||||
"properties": {
|
||||
"models": {
|
||||
"$ref": "#/$defs/PreferredModels",
|
||||
"title": "Models",
|
||||
"description": "Preferred model configurations for large and small model types"
|
||||
},
|
||||
"providers": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/$defs/ProviderConfig"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Providers",
|
||||
"description": "LLM provider configurations"
|
||||
},
|
||||
"agents": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/$defs/Agent"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Agents",
|
||||
"description": "Agent configurations for different tasks"
|
||||
},
|
||||
"mcp": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/$defs/MCP"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "MCP",
|
||||
"description": "Model Control Protocol server configurations"
|
||||
},
|
||||
"lsp": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/$defs/LSPConfig"
|
||||
},
|
||||
"type": "object",
|
||||
"title": "LSP",
|
||||
"description": "Language Server Protocol configurations"
|
||||
},
|
||||
"options": {
|
||||
"$ref": "#/$defs/Options",
|
||||
"title": "Options",
|
||||
"description": "General application options and settings"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "Crush Configuration",
|
||||
"description": "Configuration schema for the Crush application"
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
{
|
||||
"$schema": "./crush-schema.json",
|
||||
"lsp": {
|
||||
"go": {
|
||||
"command": "gopls"
|
||||
|
||||
16
go.mod
16
go.mod
@@ -17,6 +17,7 @@ require (
|
||||
github.com/charmbracelet/fang v0.1.0
|
||||
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706
|
||||
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20250627134340-c144409e381c
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a
|
||||
@@ -25,31 +26,28 @@ require (
|
||||
github.com/go-logfmt/logfmt v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/invopop/jsonschema v0.13.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/mark3labs/mcp-go v0.32.0
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/ncruces/go-sqlite3 v0.25.0
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/nxadm/tail v1.4.11
|
||||
github.com/openai/openai-go v1.8.2
|
||||
github.com/pressly/goose/v3 v3.24.2
|
||||
github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c
|
||||
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
|
||||
github.com/sahilm/fuzzy v0.1.1
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
mvdan.cc/sh/v3 v3.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0 // indirect
|
||||
github.com/charmbracelet/log v0.4.2 // indirect
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/nxadm/tail v1.4.11 // indirect
|
||||
github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
)
|
||||
|
||||
@@ -84,7 +82,7 @@ require (
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef
|
||||
github.com/charmbracelet/x/input v0.3.5-0.20250509021451-13796e822d86 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1
|
||||
github.com/charmbracelet/x/windows v0.2.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/disintegration/gift v1.1.2 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -82,14 +82,8 @@ github.com/charmbracelet/fang v0.1.0 h1:SlZS2crf3/zQh7Mr4+W+7QR1k+L08rrPX5rm5z3d
|
||||
github.com/charmbracelet/fang v0.1.0/go.mod h1:Zl/zeUQ8EtQuGyiV0ZKZlZPDowKRTzu8s/367EpN/fc=
|
||||
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe h1:i6ce4CcAlPpTj2ER69m1DBeLZ3RRcHnKExuwhKa3GfY=
|
||||
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe/go.mod h1:p3Q+aN4eQKeM5jhrmXPMgPrlKbmc59rWSnMsSA3udhk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c h1:177KMz8zHRlEZJsWzafbKYh6OdjgvTspoH+UjaxgIXY=
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250523195325-2d1af06b557c/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 h1:X0tsNa2UHCKNw+illiavosasVzqioRo32SRV35iwr2I=
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71/go.mod h1:EJWvaCrhOhNGVZMvcjc0yVryl4qqpMs8tz0r9WyEkdQ=
|
||||
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
|
||||
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE=
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM=
|
||||
github.com/charmbracelet/x/ansi v0.9.3-0.20250602153603-fb931ed90413 h1:L07QkDqRF274IZ2UJ/mCTL8DR95efU9BNWLYCDXEjvQ=
|
||||
|
||||
@@ -57,7 +57,8 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) {
|
||||
|
||||
cfg := config.Get()
|
||||
|
||||
coderAgentCfg := cfg.Agents[config.AgentCoder]
|
||||
// TODO: remove the concept of agent config most likely
|
||||
coderAgentCfg := cfg.Agents["coder"]
|
||||
if coderAgentCfg.ID == "" {
|
||||
return nil, fmt.Errorf("coder agent configuration is missing")
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
|
||||
defer cancel()
|
||||
|
||||
// Initialize with the initialization context
|
||||
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
|
||||
_, err = lspClient.InitializeLSPClient(initCtx, config.Get().WorkingDir())
|
||||
if err != nil {
|
||||
logging.Error("Initialize failed", "name", name, "error", err)
|
||||
// Clean up the client to prevent resource leaks
|
||||
@@ -91,7 +91,7 @@ func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceW
|
||||
app.restartLSPClient(ctx, name)
|
||||
})
|
||||
|
||||
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
|
||||
workspaceWatcher.WatchWorkspace(ctx, config.Get().WorkingDir())
|
||||
logging.Info("Workspace watcher stopped", "client", name)
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,71 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var testConfigDir string
|
||||
|
||||
func baseConfigPath() string {
|
||||
if testConfigDir != "" {
|
||||
return testConfigDir
|
||||
}
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
return filepath.Join(xdgConfigHome, "crush")
|
||||
}
|
||||
|
||||
// return the path to the main config directory
|
||||
// for windows, it should be in `%LOCALAPPDATA%/crush/`
|
||||
// for linux and macOS, it should be in `$HOME/.config/crush/`
|
||||
if runtime.GOOS == "windows" {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
|
||||
}
|
||||
return filepath.Join(localAppData, appName)
|
||||
}
|
||||
|
||||
return filepath.Join(os.Getenv("HOME"), ".config", appName)
|
||||
}
|
||||
|
||||
func baseDataPath() string {
|
||||
if testConfigDir != "" {
|
||||
return testConfigDir
|
||||
}
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
return filepath.Join(xdgDataHome, appName)
|
||||
}
|
||||
|
||||
// return the path to the main data directory
|
||||
// for windows, it should be in `%LOCALAPPDATA%/crush/`
|
||||
// for linux and macOS, it should be in `$HOME/.local/share/crush/`
|
||||
if runtime.GOOS == "windows" {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
|
||||
}
|
||||
return filepath.Join(localAppData, appName)
|
||||
}
|
||||
|
||||
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName)
|
||||
}
|
||||
|
||||
func ConfigPath() string {
|
||||
return filepath.Join(baseConfigPath(), fmt.Sprintf("%s.json", appName))
|
||||
}
|
||||
|
||||
func CrushInitialized() bool {
|
||||
cfgPath := ConfigPath()
|
||||
if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
|
||||
// config file does not exist, so Crush is not initialized
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -5,27 +5,53 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// InitFlagFilename is the name of the file that indicates whether the project has been initialized
|
||||
InitFlagFilename = "init"
|
||||
)
|
||||
|
||||
// ProjectInitFlag represents the initialization status for a project directory
|
||||
type ProjectInitFlag struct {
|
||||
Initialized bool `json:"initialized"`
|
||||
}
|
||||
|
||||
// ProjectNeedsInitialization checks if the current project needs initialization
|
||||
// TODO: we need to remove the global config instance keeping it now just until everything is migrated
|
||||
var (
|
||||
instance atomic.Pointer[Config]
|
||||
cwd string
|
||||
once sync.Once // Ensures the initialization happens only once
|
||||
)
|
||||
|
||||
func Init(workingDir string, debug bool) (*Config, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
cwd = workingDir
|
||||
cfg, err := Load(cwd, debug)
|
||||
if err != nil {
|
||||
logging.Error("Failed to load config", "error", err)
|
||||
}
|
||||
instance.Store(cfg)
|
||||
})
|
||||
|
||||
return instance.Load(), err
|
||||
}
|
||||
|
||||
func Get() *Config {
|
||||
return instance.Load()
|
||||
}
|
||||
|
||||
func ProjectNeedsInitialization() (bool, error) {
|
||||
if instance == nil {
|
||||
cfg := Get()
|
||||
if cfg == nil {
|
||||
return false, fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
|
||||
flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
|
||||
|
||||
// Check if the flag file exists
|
||||
_, err := os.Stat(flagFilePath)
|
||||
if err == nil {
|
||||
return false, nil
|
||||
@@ -35,8 +61,7 @@ func ProjectNeedsInitialization() (bool, error) {
|
||||
return false, fmt.Errorf("failed to check init flag file: %w", err)
|
||||
}
|
||||
|
||||
// Check if any variation of CRUSH.md already exists in working directory
|
||||
crushExists, err := crushMdExists(WorkingDirectory())
|
||||
crushExists, err := crushMdExists(cfg.WorkingDir())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check for CRUSH.md files: %w", err)
|
||||
}
|
||||
@@ -47,7 +72,6 @@ func ProjectNeedsInitialization() (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// crushMdExists checks if any case variation of crush.md exists in the directory
|
||||
func crushMdExists(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
@@ -68,12 +92,12 @@ func crushMdExists(dir string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// MarkProjectInitialized marks the current project as initialized
|
||||
func MarkProjectInitialized() error {
|
||||
if instance == nil {
|
||||
cfg := Get()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
flagFilePath := filepath.Join(instance.Options.DataDirectory, InitFlagFilename)
|
||||
flagFilePath := filepath.Join(cfg.Options.DataDirectory, InitFlagFilename)
|
||||
|
||||
file, err := os.Create(flagFilePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/client"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/pkg/env"
|
||||
"github.com/charmbracelet/crush/pkg/log"
|
||||
"github.com/charmbracelet/crush/internal/log"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
@@ -68,6 +68,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
env := env.New()
|
||||
// Configure providers
|
||||
valueResolver := NewShellVariableResolver(env)
|
||||
cfg.resolver = valueResolver
|
||||
if err := cfg.configureProviders(env, valueResolver, providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to configure providers: %w", err)
|
||||
}
|
||||
@@ -81,6 +82,36 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to configure selected models: %w", err)
|
||||
}
|
||||
|
||||
// TODO: remove the agents concept from the config
|
||||
agents := map[string]Agent{
|
||||
"coder": {
|
||||
ID: "coder",
|
||||
Name: "Coder",
|
||||
Description: "An agent that helps with executing coding tasks.",
|
||||
Model: SelectedModelTypeLarge,
|
||||
ContextPaths: cfg.Options.ContextPaths,
|
||||
// All tools allowed
|
||||
},
|
||||
"task": {
|
||||
ID: "task",
|
||||
Name: "Task",
|
||||
Description: "An agent that helps with searching for context and finding implementation details.",
|
||||
Model: SelectedModelTypeLarge,
|
||||
ContextPaths: cfg.Options.ContextPaths,
|
||||
AllowedTools: []string{
|
||||
"glob",
|
||||
"grep",
|
||||
"ls",
|
||||
"sourcegraph",
|
||||
"view",
|
||||
},
|
||||
// NO MCPs or LSPs by default
|
||||
AllowedMCP: map[string][]string{},
|
||||
AllowedLSP: []string{},
|
||||
},
|
||||
}
|
||||
cfg.Agents = agents
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/pkg/env"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -4,27 +4,44 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/client"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
var fur = client.New()
|
||||
|
||||
var (
|
||||
providerOnc sync.Once // Ensures the initialization happens only once
|
||||
providerList []provider.Provider
|
||||
// UseMockProviders can be set to true in tests to avoid API calls
|
||||
UseMockProviders bool
|
||||
)
|
||||
|
||||
func providersPath() string {
|
||||
return filepath.Join(baseDataPath(), "providers.json")
|
||||
type ProviderClient interface {
|
||||
GetProviders() ([]provider.Provider, error)
|
||||
}
|
||||
|
||||
func saveProviders(providers []provider.Provider) error {
|
||||
path := providersPath()
|
||||
var (
|
||||
providerOnce sync.Once
|
||||
providerList []provider.Provider
|
||||
)
|
||||
|
||||
// file to cache provider data
|
||||
func providerCacheFileData() string {
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
return filepath.Join(xdgDataHome, appName)
|
||||
}
|
||||
|
||||
// return the path to the main data directory
|
||||
// for windows, it should be in `%LOCALAPPDATA%/crush/`
|
||||
// for linux and macOS, it should be in `$HOME/.local/share/crush/`
|
||||
if runtime.GOOS == "windows" {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
|
||||
}
|
||||
return filepath.Join(localAppData, appName)
|
||||
}
|
||||
|
||||
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
|
||||
}
|
||||
|
||||
func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
@@ -38,8 +55,7 @@ func saveProviders(providers []provider.Provider) error {
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func loadProviders() ([]provider.Provider, error) {
|
||||
path := providersPath()
|
||||
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -50,34 +66,33 @@ func loadProviders() ([]provider.Provider, error) {
|
||||
return providers, err
|
||||
}
|
||||
|
||||
func Providers() []provider.Provider {
|
||||
providerOnc.Do(func() {
|
||||
// Use mock providers when testing
|
||||
if UseMockProviders {
|
||||
providerList = MockProviders()
|
||||
return
|
||||
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
|
||||
providers, err := client.GetProviders()
|
||||
if err != nil {
|
||||
fallbackToCache, err := loadProvidersFromCache(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providers = fallbackToCache
|
||||
} else {
|
||||
if err := saveProvidersInCache(path, providerList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// Try to get providers from upstream API
|
||||
if providers, err := fur.GetProviders(); err == nil {
|
||||
providerList = providers
|
||||
// Save providers locally for future fallback
|
||||
_ = saveProviders(providers)
|
||||
} else {
|
||||
// If upstream fails, try to load from local cache
|
||||
if localProviders, localErr := loadProviders(); localErr == nil {
|
||||
providerList = localProviders
|
||||
} else {
|
||||
// If both fail, return empty list
|
||||
providerList = []provider.Provider{}
|
||||
}
|
||||
}
|
||||
func Providers() ([]provider.Provider, error) {
|
||||
return LoadProviders(client.New())
|
||||
}
|
||||
|
||||
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
|
||||
var err error
|
||||
providerOnce.Do(func() {
|
||||
providerList, err = loadProviders(providerCacheFileData(), client)
|
||||
})
|
||||
return providerList
|
||||
}
|
||||
|
||||
// ResetProviders resets the provider cache. Useful for testing.
|
||||
func ResetProviders() {
|
||||
providerOnc = sync.Once{}
|
||||
providerList = nil
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providerList, nil
|
||||
}
|
||||
|
||||
@@ -1,293 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
// MockProviders returns a mock list of providers for testing.
|
||||
// This avoids making API calls during tests and provides consistent test data.
|
||||
// Simplified version with only default models from each provider.
|
||||
func MockProviders() []provider.Provider {
|
||||
return []provider.Provider{
|
||||
{
|
||||
Name: "Anthropic",
|
||||
ID: provider.InferenceProviderAnthropic,
|
||||
APIKey: "$ANTHROPIC_API_KEY",
|
||||
APIEndpoint: "$ANTHROPIC_API_ENDPOINT",
|
||||
Type: provider.TypeAnthropic,
|
||||
DefaultLargeModelID: "claude-sonnet-4-20250514",
|
||||
DefaultSmallModelID: "claude-3-5-haiku-20241022",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "claude-sonnet-4-20250514",
|
||||
Name: "Claude Sonnet 4",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MOut: 15.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "claude-3-5-haiku-20241022",
|
||||
Name: "Claude 3.5 Haiku",
|
||||
CostPer1MIn: 0.8,
|
||||
CostPer1MOut: 4.0,
|
||||
CostPer1MInCached: 1.0,
|
||||
CostPer1MOutCached: 0.08,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 5000,
|
||||
CanReason: false,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OpenAI",
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "$OPENAI_API_ENDPOINT",
|
||||
Type: provider.TypeOpenAI,
|
||||
DefaultLargeModelID: "codex-mini-latest",
|
||||
DefaultSmallModelID: "gpt-4o",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "codex-mini-latest",
|
||||
Name: "Codex Mini",
|
||||
CostPer1MIn: 1.5,
|
||||
CostPer1MOut: 6.0,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.375,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
HasReasoningEffort: true,
|
||||
DefaultReasoningEffort: "medium",
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "gpt-4o",
|
||||
Name: "GPT-4o",
|
||||
CostPer1MIn: 2.5,
|
||||
CostPer1MOut: 10.0,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 1.25,
|
||||
ContextWindow: 128000,
|
||||
DefaultMaxTokens: 20000,
|
||||
CanReason: false,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Google Gemini",
|
||||
ID: provider.InferenceProviderGemini,
|
||||
APIKey: "$GEMINI_API_KEY",
|
||||
APIEndpoint: "$GEMINI_API_ENDPOINT",
|
||||
Type: provider.TypeGemini,
|
||||
DefaultLargeModelID: "gemini-2.5-pro",
|
||||
DefaultSmallModelID: "gemini-2.5-flash",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Name: "Gemini 2.5 Pro",
|
||||
CostPer1MIn: 1.25,
|
||||
CostPer1MOut: 10.0,
|
||||
CostPer1MInCached: 1.625,
|
||||
CostPer1MOutCached: 0.31,
|
||||
ContextWindow: 1048576,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-flash",
|
||||
Name: "Gemini 2.5 Flash",
|
||||
CostPer1MIn: 0.3,
|
||||
CostPer1MOut: 2.5,
|
||||
CostPer1MInCached: 0.3833,
|
||||
CostPer1MOutCached: 0.075,
|
||||
ContextWindow: 1048576,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "xAI",
|
||||
ID: provider.InferenceProviderXAI,
|
||||
APIKey: "$XAI_API_KEY",
|
||||
APIEndpoint: "https://api.x.ai/v1",
|
||||
Type: provider.TypeXAI,
|
||||
DefaultLargeModelID: "grok-3",
|
||||
DefaultSmallModelID: "grok-3-mini",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "grok-3",
|
||||
Name: "Grok 3",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MOut: 15.0,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.75,
|
||||
ContextWindow: 131072,
|
||||
DefaultMaxTokens: 20000,
|
||||
CanReason: false,
|
||||
SupportsImages: false,
|
||||
},
|
||||
{
|
||||
ID: "grok-3-mini",
|
||||
Name: "Grok 3 Mini",
|
||||
CostPer1MIn: 0.3,
|
||||
CostPer1MOut: 0.5,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.075,
|
||||
ContextWindow: 131072,
|
||||
DefaultMaxTokens: 20000,
|
||||
CanReason: true,
|
||||
SupportsImages: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Azure OpenAI",
|
||||
ID: provider.InferenceProviderAzure,
|
||||
APIKey: "$AZURE_OPENAI_API_KEY",
|
||||
APIEndpoint: "$AZURE_OPENAI_API_ENDPOINT",
|
||||
Type: provider.TypeAzure,
|
||||
DefaultLargeModelID: "o4-mini",
|
||||
DefaultSmallModelID: "gpt-4o",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "o4-mini",
|
||||
Name: "o4 Mini",
|
||||
CostPer1MIn: 1.1,
|
||||
CostPer1MOut: 4.4,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.275,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
HasReasoningEffort: false,
|
||||
DefaultReasoningEffort: "medium",
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "gpt-4o",
|
||||
Name: "GPT-4o",
|
||||
CostPer1MIn: 2.5,
|
||||
CostPer1MOut: 10.0,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 1.25,
|
||||
ContextWindow: 128000,
|
||||
DefaultMaxTokens: 20000,
|
||||
CanReason: false,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "AWS Bedrock",
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
Type: provider.TypeBedrock,
|
||||
DefaultLargeModelID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
DefaultSmallModelID: "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
Name: "AWS Claude Sonnet 4",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MOut: 15.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
Name: "AWS Claude 3.5 Haiku",
|
||||
CostPer1MIn: 0.8,
|
||||
CostPer1MOut: 4.0,
|
||||
CostPer1MInCached: 1.0,
|
||||
CostPer1MOutCached: 0.08,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: false,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Google Vertex AI",
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
Type: provider.TypeVertexAI,
|
||||
DefaultLargeModelID: "gemini-2.5-pro",
|
||||
DefaultSmallModelID: "gemini-2.5-flash",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Name: "Gemini 2.5 Pro",
|
||||
CostPer1MIn: 1.25,
|
||||
CostPer1MOut: 10.0,
|
||||
CostPer1MInCached: 1.625,
|
||||
CostPer1MOutCached: 0.31,
|
||||
ContextWindow: 1048576,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-flash",
|
||||
Name: "Gemini 2.5 Flash",
|
||||
CostPer1MIn: 0.3,
|
||||
CostPer1MOut: 2.5,
|
||||
CostPer1MInCached: 0.3833,
|
||||
CostPer1MOutCached: 0.075,
|
||||
ContextWindow: 1048576,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OpenRouter",
|
||||
ID: provider.InferenceProviderOpenRouter,
|
||||
APIKey: "$OPENROUTER_API_KEY",
|
||||
APIEndpoint: "https://openrouter.ai/api/v1",
|
||||
Type: provider.TypeOpenAI,
|
||||
DefaultLargeModelID: "anthropic/claude-sonnet-4",
|
||||
DefaultSmallModelID: "anthropic/claude-haiku-3.5",
|
||||
Models: []provider.Model{
|
||||
{
|
||||
ID: "anthropic/claude-sonnet-4",
|
||||
Name: "Anthropic: Claude Sonnet 4",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MOut: 15.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 32000,
|
||||
CanReason: true,
|
||||
SupportsImages: true,
|
||||
},
|
||||
{
|
||||
ID: "anthropic/claude-haiku-3.5",
|
||||
Name: "Anthropic: Claude 3.5 Haiku",
|
||||
CostPer1MIn: 0.8,
|
||||
CostPer1MOut: 4.0,
|
||||
CostPer1MInCached: 1.0,
|
||||
CostPer1MOutCached: 0.08,
|
||||
ContextWindow: 200000,
|
||||
DefaultMaxTokens: 4096,
|
||||
CanReason: false,
|
||||
SupportsImages: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,81 +1,73 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProviders_MockEnabled(t *testing.T) {
|
||||
originalUseMock := UseMockProviders
|
||||
UseMockProviders = true
|
||||
defer func() {
|
||||
UseMockProviders = originalUseMock
|
||||
ResetProviders()
|
||||
}()
|
||||
|
||||
ResetProviders()
|
||||
providers := Providers()
|
||||
require.NotEmpty(t, providers)
|
||||
|
||||
providerIDs := make(map[provider.InferenceProvider]bool)
|
||||
for _, p := range providers {
|
||||
providerIDs[p.ID] = true
|
||||
}
|
||||
|
||||
assert.True(t, providerIDs[provider.InferenceProviderAnthropic])
|
||||
assert.True(t, providerIDs[provider.InferenceProviderOpenAI])
|
||||
assert.True(t, providerIDs[provider.InferenceProviderGemini])
|
||||
type mockProviderClient struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
func TestProviders_ResetFunctionality(t *testing.T) {
|
||||
UseMockProviders = true
|
||||
defer func() {
|
||||
UseMockProviders = false
|
||||
ResetProviders()
|
||||
}()
|
||||
|
||||
providers1 := Providers()
|
||||
require.NotEmpty(t, providers1)
|
||||
|
||||
ResetProviders()
|
||||
providers2 := Providers()
|
||||
require.NotEmpty(t, providers2)
|
||||
|
||||
assert.Equal(t, len(providers1), len(providers2))
|
||||
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
|
||||
if m.shouldFail {
|
||||
return nil, errors.New("failed to load providers")
|
||||
}
|
||||
return []provider.Provider{
|
||||
{
|
||||
Name: "Mock",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestProviders_ModelCapabilities(t *testing.T) {
|
||||
originalUseMock := UseMockProviders
|
||||
UseMockProviders = true
|
||||
defer func() {
|
||||
UseMockProviders = originalUseMock
|
||||
ResetProviders()
|
||||
}()
|
||||
func TestProvider_loadProvidersNoIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: false}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, providers)
|
||||
assert.Len(t, providers, 1)
|
||||
|
||||
ResetProviders()
|
||||
providers := Providers()
|
||||
|
||||
var openaiProvider provider.Provider
|
||||
for _, p := range providers {
|
||||
if p.ID == provider.InferenceProviderOpenAI {
|
||||
openaiProvider = p
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, openaiProvider.ID)
|
||||
|
||||
var foundReasoning, foundNonReasoning bool
|
||||
for _, model := range openaiProvider.Models {
|
||||
if model.CanReason && model.HasReasoningEffort {
|
||||
foundReasoning = true
|
||||
} else if !model.CanReason {
|
||||
foundNonReasoning = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundReasoning)
|
||||
assert.True(t, foundNonReasoning)
|
||||
// check if file got saved
|
||||
fileInfo, err := os.Stat(tmpPath)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersWithIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
// store providers to a temporary file
|
||||
oldProviders := []provider.Provider{
|
||||
{
|
||||
Name: "OldProvider",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(oldProviders)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal old providers: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(tmpPath, data, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write old providers to file: %v", err)
|
||||
}
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, providers)
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/shell"
|
||||
"github.com/charmbracelet/crush/pkg/env"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
)
|
||||
|
||||
type VariableResolver interface {
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/pkg/env"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
"github.com/charmbracelet/crush/internal/shell"
|
||||
)
|
||||
|
||||
// ExecuteCommand executes a shell command and returns the output
|
||||
// This is a shared utility that can be used by both provider config and tools
|
||||
func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) {
|
||||
if workingDir == "" {
|
||||
workingDir = WorkingDirectory()
|
||||
}
|
||||
|
||||
persistentShell := shell.NewShell(&shell.Options{WorkingDir: workingDir})
|
||||
|
||||
stdout, stderr, err := persistentShell.Exec(ctx, command)
|
||||
if err != nil {
|
||||
logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr)
|
||||
return "", fmt.Errorf("command execution failed: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(stdout), nil
|
||||
}
|
||||
|
||||
// ResolveAPIKey resolves an API key that can be either:
|
||||
// - A direct string value
|
||||
// - An environment variable (prefixed with $)
|
||||
// - A shell command (wrapped in $(...))
|
||||
func ResolveAPIKey(apiKey string) (string, error) {
|
||||
if !strings.HasPrefix(apiKey, "$") {
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") {
|
||||
command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")")
|
||||
logging.Debug("Resolving API key from command", "command", command)
|
||||
return resolveCommandAPIKey(command)
|
||||
}
|
||||
|
||||
envVar := strings.TrimPrefix(apiKey, "$")
|
||||
if value := os.Getenv(envVar); value != "" {
|
||||
logging.Debug("Resolved environment variable", "envVar", envVar, "value", value)
|
||||
return value, nil
|
||||
}
|
||||
|
||||
logging.Debug("Environment variable not found", "envVar", envVar)
|
||||
|
||||
return "", fmt.Errorf("environment variable %s not found", envVar)
|
||||
}
|
||||
|
||||
// resolveCommandAPIKey executes a command to get an API key, with caching support
|
||||
func resolveCommandAPIKey(command string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
logging.Debug("Executing command for API key", "command", command)
|
||||
|
||||
workingDir := WorkingDirectory()
|
||||
|
||||
result, err := ExecuteCommand(ctx, command, workingDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute API key command: %w", err)
|
||||
}
|
||||
logging.Debug("Command executed successfully", "command", command, "result", result)
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,462 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_Validate_ValidConfig(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Models: PreferredModels{
|
||||
Large: PreferredModel{
|
||||
ModelID: "gpt-4",
|
||||
Provider: provider.InferenceProviderOpenAI,
|
||||
},
|
||||
Small: PreferredModel{
|
||||
ModelID: "gpt-3.5-turbo",
|
||||
Provider: provider.InferenceProviderOpenAI,
|
||||
},
|
||||
},
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
DefaultLargeModel: "gpt-4",
|
||||
DefaultSmallModel: "gpt-3.5-turbo",
|
||||
Models: []Model{
|
||||
{
|
||||
ID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
ContextWindow: 8192,
|
||||
DefaultMaxTokens: 4096,
|
||||
CostPer1MIn: 30.0,
|
||||
CostPer1MOut: 60.0,
|
||||
},
|
||||
{
|
||||
ID: "gpt-3.5-turbo",
|
||||
Name: "GPT-3.5 Turbo",
|
||||
ContextWindow: 4096,
|
||||
DefaultMaxTokens: 2048,
|
||||
CostPer1MIn: 1.5,
|
||||
CostPer1MOut: 2.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentCoder: {
|
||||
ID: AgentCoder,
|
||||
Name: "Coder",
|
||||
Description: "An agent that helps with executing coding tasks.",
|
||||
Model: LargeModel,
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
AgentTask: {
|
||||
ID: AgentTask,
|
||||
Name: "Task",
|
||||
Description: "An agent that helps with searching for context and finding implementation details.",
|
||||
Model: LargeModel,
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
AllowedTools: []string{"glob", "grep", "ls", "sourcegraph", "view"},
|
||||
AllowedMCP: map[string][]string{},
|
||||
AllowedLSP: []string{},
|
||||
},
|
||||
},
|
||||
MCP: map[string]MCP{},
|
||||
LSP: map[string]LSPConfig{},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingAPIKey(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
// Missing APIKey
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "API key is required")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidProviderType(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.Type("invalid"),
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid provider type")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_CustomProviderMissingBaseURL(t *testing.T) {
|
||||
customProvider := provider.InferenceProvider("custom-provider")
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
customProvider: {
|
||||
ID: customProvider,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
// Missing BaseURL for custom provider
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "BaseURL is required for custom providers")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_DuplicateModelIDs(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
Models: []Model{
|
||||
{
|
||||
ID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
ContextWindow: 8192,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
{
|
||||
ID: "gpt-4", // Duplicate ID
|
||||
Name: "GPT-4 Duplicate",
|
||||
ContextWindow: 8192,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate model ID")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidModelFields(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
Models: []Model{
|
||||
{
|
||||
ID: "", // Empty ID
|
||||
Name: "GPT-4",
|
||||
ContextWindow: 0, // Invalid context window
|
||||
DefaultMaxTokens: -1, // Invalid max tokens
|
||||
CostPer1MIn: -5.0, // Negative cost
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
validationErr := err.(ValidationErrors)
|
||||
assert.True(t, len(validationErr) >= 4) // Should have multiple validation errors
|
||||
}
|
||||
|
||||
func TestConfig_Validate_DefaultModelNotFound(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
DefaultLargeModel: "nonexistent-model",
|
||||
Models: []Model{
|
||||
{
|
||||
ID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
ContextWindow: 8192,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "default large model 'nonexistent-model' not found")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_AgentIDMismatch(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentCoder: {
|
||||
ID: AgentTask, // Wrong ID
|
||||
Name: "Coder",
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "agent ID mismatch")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidAgentModelType(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentCoder: {
|
||||
ID: AgentCoder,
|
||||
Name: "Coder",
|
||||
Model: ModelType("invalid"),
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid model type")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_UnknownTool(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentID("custom-agent"): {
|
||||
ID: AgentID("custom-agent"),
|
||||
Name: "Custom Agent",
|
||||
Model: LargeModel,
|
||||
AllowedTools: []string{"unknown-tool"},
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown tool")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MCPReference(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentID("custom-agent"): {
|
||||
ID: AgentID("custom-agent"),
|
||||
Name: "Custom Agent",
|
||||
Model: LargeModel,
|
||||
AllowedMCP: map[string][]string{"nonexistent-mcp": nil},
|
||||
},
|
||||
},
|
||||
MCP: map[string]MCP{}, // Empty MCP map
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "referenced MCP 'nonexistent-mcp' not found")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidMCPType(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCP: map[string]MCP{
|
||||
"test-mcp": {
|
||||
Type: MCPType("invalid"),
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid MCP type")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MCPMissingCommand(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCP: map[string]MCP{
|
||||
"test-mcp": {
|
||||
Type: MCPStdio,
|
||||
// Missing Command
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "command is required for stdio MCP")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_LSPMissingCommand(t *testing.T) {
|
||||
cfg := &Config{
|
||||
LSP: map[string]LSPConfig{
|
||||
"test-lsp": {
|
||||
// Missing Command
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "command is required for LSP")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_NoValidProviders(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
Disabled: true, // Disabled
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one non-disabled provider is required")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingDefaultAgents(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: map[provider.InferenceProvider]ProviderConfig{
|
||||
provider.InferenceProviderOpenAI: {
|
||||
ID: provider.InferenceProviderOpenAI,
|
||||
APIKey: "test-key",
|
||||
ProviderType: provider.TypeOpenAI,
|
||||
},
|
||||
},
|
||||
Agents: map[AgentID]Agent{}, // Missing default agents
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "coder agent is required")
|
||||
assert.Contains(t, err.Error(), "task agent is required")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_KnownAgentProtection(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Agents: map[AgentID]Agent{
|
||||
AgentCoder: {
|
||||
ID: AgentCoder,
|
||||
Name: "Modified Coder", // Should not be allowed
|
||||
Description: "Modified description", // Should not be allowed
|
||||
Model: LargeModel,
|
||||
},
|
||||
},
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "coder agent name cannot be changed")
|
||||
assert.Contains(t, err.Error(), "coder agent description cannot be changed")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_EmptyDataDirectory(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Options: Options{
|
||||
DataDirectory: "", // Empty
|
||||
ContextPaths: []string{"CRUSH.md"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "data directory is required")
|
||||
}
|
||||
|
||||
func TestConfig_Validate_EmptyContextPath(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Options: Options{
|
||||
DataDirectory: ".crush",
|
||||
ContextPaths: []string{""}, // Empty context path
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context path cannot be empty")
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) {
|
||||
// remove the cwd prefix and ensure consistent path format
|
||||
// this prevents issues with absolute paths in different environments
|
||||
cwd := config.WorkingDirectory()
|
||||
cwd := config.Get().WorkingDir()
|
||||
fileName = strings.TrimPrefix(fileName, cwd)
|
||||
fileName = strings.TrimPrefix(fileName, "/")
|
||||
|
||||
|
||||
0
pkg/env/env.go → internal/env/env.go
vendored
0
pkg/env/env.go → internal/env/env.go
vendored
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
fur "github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/llm/provider"
|
||||
@@ -49,7 +50,7 @@ type AgentEvent struct {
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[AgentEvent]
|
||||
Model() config.Model
|
||||
Model() fur.Model
|
||||
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
|
||||
Cancel(sessionID string)
|
||||
CancelAll()
|
||||
@@ -76,9 +77,9 @@ type agent struct {
|
||||
activeRequests sync.Map
|
||||
}
|
||||
|
||||
var agentPromptMap = map[config.AgentID]prompt.PromptID{
|
||||
config.AgentCoder: prompt.PromptCoder,
|
||||
config.AgentTask: prompt.PromptTask,
|
||||
var agentPromptMap = map[string]prompt.PromptID{
|
||||
"coder": prompt.PromptCoder,
|
||||
"task": prompt.PromptTask,
|
||||
}
|
||||
|
||||
func NewAgent(
|
||||
@@ -109,8 +110,8 @@ func NewAgent(
|
||||
tools.NewWriteTool(lspClients, permissions, history),
|
||||
}
|
||||
|
||||
if agentCfg.ID == config.AgentCoder {
|
||||
taskAgentCfg := config.Get().Agents[config.AgentTask]
|
||||
if agentCfg.ID == "coder" {
|
||||
taskAgentCfg := config.Get().Agents["task"]
|
||||
if taskAgentCfg.ID == "" {
|
||||
return nil, fmt.Errorf("task agent not found in config")
|
||||
}
|
||||
@@ -130,13 +131,13 @@ func NewAgent(
|
||||
}
|
||||
|
||||
allTools = append(allTools, otherTools...)
|
||||
providerCfg := config.GetAgentProvider(agentCfg.ID)
|
||||
if providerCfg.ID == "" {
|
||||
providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
|
||||
if providerCfg == nil {
|
||||
return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
|
||||
}
|
||||
model := config.GetAgentModel(agentCfg.ID)
|
||||
model := config.Get().GetModelByType(agentCfg.Model)
|
||||
|
||||
if model.ID == "" {
|
||||
if model == nil {
|
||||
return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
|
||||
}
|
||||
|
||||
@@ -148,51 +149,40 @@ func NewAgent(
|
||||
provider.WithModel(agentCfg.Model),
|
||||
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
|
||||
}
|
||||
agentProvider, err := provider.NewProvider(providerCfg, opts...)
|
||||
agentProvider, err := provider.NewProvider(*providerCfg, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
smallModelCfg := cfg.Models.Small
|
||||
var smallModel config.Model
|
||||
|
||||
var smallModelProviderCfg config.ProviderConfig
|
||||
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
|
||||
var smallModelProviderCfg *config.ProviderConfig
|
||||
if smallModelCfg.Provider == providerCfg.ID {
|
||||
smallModelProviderCfg = providerCfg
|
||||
} else {
|
||||
for _, p := range cfg.Providers {
|
||||
if p.ID == smallModelCfg.Provider {
|
||||
smallModelProviderCfg = p
|
||||
break
|
||||
}
|
||||
}
|
||||
smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
|
||||
|
||||
if smallModelProviderCfg.ID == "" {
|
||||
return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
|
||||
}
|
||||
}
|
||||
for _, m := range smallModelProviderCfg.Models {
|
||||
if m.ID == smallModelCfg.ModelID {
|
||||
smallModel = m
|
||||
break
|
||||
}
|
||||
}
|
||||
smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
|
||||
if smallModel.ID == "" {
|
||||
return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
|
||||
return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
|
||||
}
|
||||
|
||||
titleOpts := []provider.ProviderClientOption{
|
||||
provider.WithModel(config.SmallModel),
|
||||
provider.WithModel(config.SelectedModelTypeSmall),
|
||||
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
|
||||
}
|
||||
titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
|
||||
titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summarizeOpts := []provider.ProviderClientOption{
|
||||
provider.WithModel(config.SmallModel),
|
||||
provider.WithModel(config.SelectedModelTypeSmall),
|
||||
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
|
||||
}
|
||||
summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
|
||||
summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -225,8 +215,8 @@ func NewAgent(
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (a *agent) Model() config.Model {
|
||||
return config.GetAgentModel(a.agentCfg.ID)
|
||||
func (a *agent) Model() fur.Model {
|
||||
return *config.Get().GetModelByType(a.agentCfg.Model)
|
||||
}
|
||||
|
||||
func (a *agent) Cancel(sessionID string) {
|
||||
@@ -610,7 +600,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
|
||||
sess, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
@@ -819,7 +809,7 @@ func (a *agent) UpdateModel() error {
|
||||
cfg := config.Get()
|
||||
|
||||
// Get current provider configuration
|
||||
currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
|
||||
currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
|
||||
if currentProviderCfg.ID == "" {
|
||||
return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
|
||||
}
|
||||
@@ -827,7 +817,7 @@ func (a *agent) UpdateModel() error {
|
||||
// Check if provider has changed
|
||||
if string(currentProviderCfg.ID) != a.providerID {
|
||||
// Provider changed, need to recreate the main provider
|
||||
model := config.GetAgentModel(a.agentCfg.ID)
|
||||
model := cfg.GetModelByType(a.agentCfg.Model)
|
||||
if model.ID == "" {
|
||||
return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
|
||||
}
|
||||
@@ -842,7 +832,7 @@ func (a *agent) UpdateModel() error {
|
||||
provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
|
||||
}
|
||||
|
||||
newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
|
||||
newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new provider: %w", err)
|
||||
}
|
||||
@@ -853,7 +843,7 @@ func (a *agent) UpdateModel() error {
|
||||
}
|
||||
|
||||
// Check if small model provider has changed (affects title and summarize providers)
|
||||
smallModelCfg := cfg.Models.Small
|
||||
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
|
||||
var smallModelProviderCfg config.ProviderConfig
|
||||
|
||||
for _, p := range cfg.Providers {
|
||||
@@ -869,20 +859,14 @@ func (a *agent) UpdateModel() error {
|
||||
|
||||
// Check if summarize provider has changed
|
||||
if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
|
||||
var smallModel config.Model
|
||||
for _, m := range smallModelProviderCfg.Models {
|
||||
if m.ID == smallModelCfg.ModelID {
|
||||
smallModel = m
|
||||
break
|
||||
}
|
||||
}
|
||||
if smallModel.ID == "" {
|
||||
return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
|
||||
smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
|
||||
if smallModel == nil {
|
||||
return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
|
||||
}
|
||||
|
||||
// Recreate title provider
|
||||
titleOpts := []provider.ProviderClientOption{
|
||||
provider.WithModel(config.SmallModel),
|
||||
provider.WithModel(config.SelectedModelTypeSmall),
|
||||
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
|
||||
// We want the title to be short, so we limit the max tokens
|
||||
provider.WithMaxTokens(40),
|
||||
@@ -894,7 +878,7 @@ func (a *agent) UpdateModel() error {
|
||||
|
||||
// Recreate summarize provider
|
||||
summarizeOpts := []provider.ProviderClientOption{
|
||||
provider.WithModel(config.SmallModel),
|
||||
provider.WithModel(config.SelectedModelTypeSmall),
|
||||
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
|
||||
}
|
||||
newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
type mcpTool struct {
|
||||
mcpName string
|
||||
tool mcp.Tool
|
||||
mcpConfig config.MCP
|
||||
mcpConfig config.MCPConfig
|
||||
permissions permission.Service
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
||||
p := b.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
Path: config.Get().WorkingDir(),
|
||||
ToolName: b.Info().Name,
|
||||
Action: "execute",
|
||||
Description: permissionDescription,
|
||||
@@ -142,7 +142,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
||||
return tools.NewTextErrorResponse("invalid mcp type"), nil
|
||||
}
|
||||
|
||||
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
|
||||
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
|
||||
return &mcpTool{
|
||||
mcpName: name,
|
||||
tool: tool,
|
||||
@@ -153,7 +153,7 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
|
||||
|
||||
var mcpTools []tools.BaseTool
|
||||
|
||||
func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
|
||||
func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
|
||||
var stdioTools []tools.BaseTool
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
|
||||
@@ -14,12 +14,12 @@ import (
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
)
|
||||
|
||||
func CoderPrompt(p provider.InferenceProvider, contextFiles ...string) string {
|
||||
func CoderPrompt(p string, contextFiles ...string) string {
|
||||
var basePrompt string
|
||||
switch p {
|
||||
case provider.InferenceProviderOpenAI:
|
||||
case string(provider.InferenceProviderOpenAI):
|
||||
basePrompt = baseOpenAICoderPrompt
|
||||
case provider.InferenceProviderGemini, provider.InferenceProviderVertexAI:
|
||||
case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
|
||||
basePrompt = baseGeminiCoderPrompt
|
||||
default:
|
||||
basePrompt = baseAnthropicCoderPrompt
|
||||
@@ -380,7 +380,7 @@ Your core function is efficient and safe assistance. Balance extreme conciseness
|
||||
`
|
||||
|
||||
func getEnvironmentInfo() string {
|
||||
cwd := config.WorkingDirectory()
|
||||
cwd := config.Get().WorkingDir()
|
||||
isGit := isGitRepo(cwd)
|
||||
platform := runtime.GOOS
|
||||
date := time.Now().Format("1/2/2006")
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
type PromptID string
|
||||
@@ -20,17 +19,17 @@ const (
|
||||
PromptDefault PromptID = "default"
|
||||
)
|
||||
|
||||
func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPaths ...string) string {
|
||||
func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
|
||||
basePrompt := ""
|
||||
switch promptID {
|
||||
case PromptCoder:
|
||||
basePrompt = CoderPrompt(provider)
|
||||
case PromptTitle:
|
||||
basePrompt = TitlePrompt(provider)
|
||||
basePrompt = TitlePrompt()
|
||||
case PromptTask:
|
||||
basePrompt = TaskPrompt(provider)
|
||||
basePrompt = TaskPrompt()
|
||||
case PromptSummarizer:
|
||||
basePrompt = SummarizerPrompt(provider)
|
||||
basePrompt = SummarizerPrompt()
|
||||
default:
|
||||
basePrompt = "You are a helpful assistant"
|
||||
}
|
||||
@@ -38,7 +37,7 @@ func GetPrompt(promptID PromptID, provider provider.InferenceProvider, contextPa
|
||||
}
|
||||
|
||||
func getContextFromPaths(contextPaths []string) string {
|
||||
return processContextPaths(config.WorkingDirectory(), contextPaths)
|
||||
return processContextPaths(config.Get().WorkingDir(), contextPaths)
|
||||
}
|
||||
|
||||
func processContextPaths(workDir string, paths []string) string {
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
func SummarizerPrompt(_ provider.InferenceProvider) string {
|
||||
func SummarizerPrompt() string {
|
||||
return `You are a helpful AI assistant tasked with summarizing conversations.
|
||||
|
||||
When asked to summarize, provide a detailed but concise summary of the conversation.
|
||||
|
||||
@@ -2,11 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
func TaskPrompt(_ provider.InferenceProvider) string {
|
||||
func TaskPrompt() string {
|
||||
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
|
||||
Notes:
|
||||
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
func TitlePrompt(_ provider.InferenceProvider) string {
|
||||
func TitlePrompt() string {
|
||||
return `you will generate a short title based on the first message a user begins a conversation with
|
||||
- ensure it is not more than 50 characters long
|
||||
- the title should be a summary of the user's message
|
||||
|
||||
@@ -153,9 +153,9 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to
|
||||
model := a.providerOptions.model(a.providerOptions.modelType)
|
||||
var thinkingParam anthropic.ThinkingConfigParamUnion
|
||||
cfg := config.Get()
|
||||
modelConfig := cfg.Models.Large
|
||||
if a.providerOptions.modelType == config.SmallModel {
|
||||
modelConfig = cfg.Models.Small
|
||||
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
|
||||
if a.providerOptions.modelType == config.SelectedModelTypeSmall {
|
||||
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
|
||||
}
|
||||
temperature := anthropic.Float(0)
|
||||
|
||||
@@ -399,7 +399,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
|
||||
}
|
||||
|
||||
if apiErr.StatusCode == 401 {
|
||||
a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
|
||||
a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
|
||||
}
|
||||
@@ -490,6 +490,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *anthropicClient) Model() config.Model {
|
||||
func (a *anthropicClient) Model() provider.Model {
|
||||
return a.providerOptions.model(a.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -31,14 +32,14 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
}
|
||||
}
|
||||
|
||||
opts.model = func(modelType config.ModelType) config.Model {
|
||||
model := config.GetModel(modelType)
|
||||
opts.model = func(modelType config.SelectedModelType) provider.Model {
|
||||
model := config.Get().GetModelByType(modelType)
|
||||
|
||||
// Prefix the model name with region
|
||||
regionPrefix := region[:2]
|
||||
modelName := model.ID
|
||||
model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
|
||||
return model
|
||||
return *model
|
||||
}
|
||||
|
||||
model := opts.model(opts.modelType)
|
||||
@@ -87,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
|
||||
return b.childProvider.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (b *bedrockClient) Model() config.Model {
|
||||
func (b *bedrockClient) Model() provider.Model {
|
||||
return b.providerOptions.model(b.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
@@ -170,9 +171,9 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
modelConfig := cfg.Models.Large
|
||||
if g.providerOptions.modelType == config.SmallModel {
|
||||
modelConfig = cfg.Models.Small
|
||||
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
|
||||
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
|
||||
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
|
||||
}
|
||||
|
||||
maxTokens := model.DefaultMaxTokens
|
||||
@@ -268,9 +269,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
logging.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
modelConfig := cfg.Models.Large
|
||||
if g.providerOptions.modelType == config.SmallModel {
|
||||
modelConfig = cfg.Models.Small
|
||||
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
|
||||
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
|
||||
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
|
||||
}
|
||||
maxTokens := model.DefaultMaxTokens
|
||||
if modelConfig.MaxTokens > 0 {
|
||||
@@ -424,7 +425,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
|
||||
|
||||
// Check for token expiration (401 Unauthorized)
|
||||
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
|
||||
g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
|
||||
g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
|
||||
}
|
||||
@@ -462,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiClient) Model() config.Model {
|
||||
func (g *geminiClient) Model() provider.Model {
|
||||
return g.providerOptions.model(g.providerOptions.modelType)
|
||||
}
|
||||
|
||||
|
||||
@@ -148,15 +148,12 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar
|
||||
model := o.providerOptions.model(o.providerOptions.modelType)
|
||||
cfg := config.Get()
|
||||
|
||||
modelConfig := cfg.Models.Large
|
||||
if o.providerOptions.modelType == config.SmallModel {
|
||||
modelConfig = cfg.Models.Small
|
||||
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
|
||||
if o.providerOptions.modelType == config.SelectedModelTypeSmall {
|
||||
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
|
||||
}
|
||||
|
||||
reasoningEffort := model.ReasoningEffort
|
||||
if modelConfig.ReasoningEffort != "" {
|
||||
reasoningEffort = modelConfig.ReasoningEffort
|
||||
}
|
||||
reasoningEffort := modelConfig.ReasoningEffort
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel(model.ID),
|
||||
@@ -363,7 +360,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
|
||||
|
||||
// Check for token expiration (401 Unauthorized)
|
||||
if apiErr.StatusCode == 401 {
|
||||
o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
|
||||
o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
|
||||
if err != nil {
|
||||
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
|
||||
}
|
||||
@@ -420,6 +417,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *openaiClient) Model() config.Model {
|
||||
func (a *openaiClient) Model() provider.Model {
|
||||
return a.providerOptions.model(a.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -55,15 +55,15 @@ type Provider interface {
|
||||
|
||||
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() config.Model
|
||||
Model() provider.Model
|
||||
}
|
||||
|
||||
type providerClientOptions struct {
|
||||
baseURL string
|
||||
config config.ProviderConfig
|
||||
apiKey string
|
||||
modelType config.ModelType
|
||||
model func(config.ModelType) config.Model
|
||||
modelType config.SelectedModelType
|
||||
model func(config.SelectedModelType) provider.Model
|
||||
disableCache bool
|
||||
systemMessage string
|
||||
maxTokens int64
|
||||
@@ -77,7 +77,7 @@ type ProviderClient interface {
|
||||
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
|
||||
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() config.Model
|
||||
Model() provider.Model
|
||||
}
|
||||
|
||||
type baseProvider[C ProviderClient] struct {
|
||||
@@ -106,11 +106,11 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
|
||||
return p.client.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() config.Model {
|
||||
func (p *baseProvider[C]) Model() provider.Model {
|
||||
return p.client.Model()
|
||||
}
|
||||
|
||||
func WithModel(model config.ModelType) ProviderClientOption {
|
||||
func WithModel(model config.SelectedModelType) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.modelType = model
|
||||
}
|
||||
@@ -135,7 +135,7 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
|
||||
}
|
||||
|
||||
func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
|
||||
resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
|
||||
resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
|
||||
}
|
||||
@@ -145,14 +145,14 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
config: cfg,
|
||||
apiKey: resolvedAPIKey,
|
||||
extraHeaders: cfg.ExtraHeaders,
|
||||
model: func(tp config.ModelType) config.Model {
|
||||
return config.GetModel(tp)
|
||||
model: func(tp config.SelectedModelType) provider.Model {
|
||||
return *config.Get().GetModelByType(tp)
|
||||
},
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch cfg.ProviderType {
|
||||
switch cfg.Type {
|
||||
case provider.TypeAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
@@ -190,5 +190,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
|
||||
return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
|
||||
}
|
||||
|
||||
@@ -317,7 +317,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
p := b.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
Path: config.Get().WorkingDir(),
|
||||
ToolName: BashToolName,
|
||||
Action: "execute",
|
||||
Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
@@ -337,7 +337,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
defer cancel()
|
||||
}
|
||||
stdout, stderr, err := shell.
|
||||
GetPersistentShell(config.WorkingDirectory()).
|
||||
GetPersistentShell(config.Get().WorkingDir()).
|
||||
Exec(ctx, params.Command)
|
||||
interrupted := shell.IsInterrupt(err)
|
||||
exitCode := shell.ExitCode(err)
|
||||
|
||||
@@ -143,7 +143,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
wd := config.WorkingDirectory()
|
||||
wd := config.Get().WorkingDir()
|
||||
params.FilePath = filepath.Join(wd, params.FilePath)
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
|
||||
content,
|
||||
filePath,
|
||||
)
|
||||
rootDir := config.WorkingDirectory()
|
||||
rootDir := config.Get().WorkingDir()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
@@ -320,7 +320,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
filePath,
|
||||
)
|
||||
|
||||
rootDir := config.WorkingDirectory()
|
||||
rootDir := config.Get().WorkingDir()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
@@ -442,7 +442,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
newContent,
|
||||
filePath,
|
||||
)
|
||||
rootDir := config.WorkingDirectory()
|
||||
rootDir := config.Get().WorkingDir()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
|
||||
@@ -133,7 +133,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
p := t.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
Path: config.Get().WorkingDir(),
|
||||
ToolName: FetchToolName,
|
||||
Action: "fetch",
|
||||
Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
|
||||
|
||||
@@ -108,7 +108,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
searchPath = config.WorkingDirectory()
|
||||
searchPath = config.Get().WorkingDir()
|
||||
}
|
||||
|
||||
files, truncated, err := globFiles(params.Pattern, searchPath, 100)
|
||||
|
||||
@@ -200,7 +200,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
searchPath = config.WorkingDirectory()
|
||||
searchPath = config.Get().WorkingDir()
|
||||
}
|
||||
|
||||
matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100)
|
||||
|
||||
@@ -107,11 +107,11 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
|
||||
searchPath := params.Path
|
||||
if searchPath == "" {
|
||||
searchPath = config.WorkingDirectory()
|
||||
searchPath = config.Get().WorkingDir()
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(searchPath) {
|
||||
searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
|
||||
searchPath = filepath.Join(config.Get().WorkingDir(), searchPath)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(searchPath); os.IsNotExist(err) {
|
||||
|
||||
@@ -117,7 +117,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
// Handle relative paths
|
||||
filePath := params.FilePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(config.WorkingDirectory(), filePath)
|
||||
filePath = filepath.Join(config.Get().WorkingDir(), filePath)
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
|
||||
@@ -122,7 +122,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
|
||||
filePath := params.FilePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(config.WorkingDirectory(), filePath)
|
||||
filePath = filepath.Join(config.Get().WorkingDir(), filePath)
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
@@ -170,7 +170,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
filePath,
|
||||
)
|
||||
|
||||
rootDir := config.WorkingDirectory()
|
||||
rootDir := config.Get().WorkingDir()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
|
||||
@@ -376,7 +376,7 @@ func (c *Client) detectServerType() ServerType {
|
||||
|
||||
// openKeyConfigFiles opens important configuration files that help initialize the server
|
||||
func (c *Client) openKeyConfigFiles(ctx context.Context) {
|
||||
workDir := config.WorkingDirectory()
|
||||
workDir := config.Get().WorkingDir()
|
||||
serverType := c.detectServerType()
|
||||
|
||||
var filesToOpen []string
|
||||
@@ -464,7 +464,7 @@ func (c *Client) pingTypeScriptServer(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// If we have no open TypeScript files, try to find and open one
|
||||
workDir := config.WorkingDirectory()
|
||||
workDir := config.Get().WorkingDir()
|
||||
err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -87,7 +87,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
|
||||
dir := filepath.Dir(opts.Path)
|
||||
if dir == "." {
|
||||
dir = config.WorkingDirectory()
|
||||
dir = config.Get().WorkingDir()
|
||||
}
|
||||
permission := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
|
||||
@@ -91,7 +91,7 @@ func (p *header) View() tea.View {
|
||||
|
||||
func (h *header) details() string {
|
||||
t := styles.CurrentTheme()
|
||||
cwd := fsext.DirTrim(fsext.PrettyPath(config.WorkingDirectory()), 4)
|
||||
cwd := fsext.DirTrim(fsext.PrettyPath(config.Get().WorkingDir()), 4)
|
||||
parts := []string{
|
||||
t.S().Muted.Render(cwd),
|
||||
}
|
||||
@@ -111,7 +111,8 @@ func (h *header) details() string {
|
||||
parts = append(parts, t.S().Error.Render(fmt.Sprintf("%s%d", styles.ErrorIcon, errorCount)))
|
||||
}
|
||||
|
||||
model := config.GetAgentModel(config.AgentCoder)
|
||||
agentCfg := config.Get().Agents["coder"]
|
||||
model := config.Get().GetModelByType(agentCfg.Model)
|
||||
percentage := (float64(h.session.CompletionTokens+h.session.PromptTokens) / float64(model.ContextWindow)) * 100
|
||||
formattedPercentage := t.S().Muted.Render(fmt.Sprintf("%d%%", int(percentage)))
|
||||
parts = append(parts, formattedPercentage)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/anim"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core"
|
||||
@@ -296,7 +295,7 @@ func (m *assistantSectionModel) View() tea.View {
|
||||
duration := finishTime.Sub(m.lastUserMessageTime)
|
||||
infoMsg := t.S().Subtle.Render(duration.String())
|
||||
icon := t.S().Subtle.Render(styles.ModelIcon)
|
||||
model := config.GetProviderModel(provider.InferenceProvider(m.message.Provider), m.message.Model)
|
||||
model := config.Get().GetModel(m.message.Provider, m.message.Model)
|
||||
modelFormatted := t.S().Muted.Render(model.Name)
|
||||
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
|
||||
return tea.NewView(
|
||||
|
||||
@@ -297,7 +297,7 @@ func (m *sidebarCmp) filesBlock() string {
|
||||
}
|
||||
|
||||
extraContent := strings.Join(statusParts, " ")
|
||||
cwd := config.WorkingDirectory() + string(os.PathSeparator)
|
||||
cwd := config.Get().WorkingDir() + string(os.PathSeparator)
|
||||
filePath := file.FilePath
|
||||
filePath = strings.TrimPrefix(filePath, cwd)
|
||||
filePath = fsext.DirTrim(fsext.PrettyPath(filePath), 2)
|
||||
@@ -474,7 +474,8 @@ func formatTokensAndCost(tokens, contextWindow int64, cost float64) string {
|
||||
}
|
||||
|
||||
func (s *sidebarCmp) currentModelBlock() string {
|
||||
model := config.GetAgentModel(config.AgentCoder)
|
||||
agentCfg := config.Get().Agents["coder"]
|
||||
model := config.Get().GetModelByType(agentCfg.Model)
|
||||
|
||||
t := styles.CurrentTheme()
|
||||
|
||||
@@ -507,7 +508,7 @@ func (m *sidebarCmp) SetSession(session session.Session) tea.Cmd {
|
||||
}
|
||||
|
||||
func cwd() string {
|
||||
cwd := config.WorkingDirectory()
|
||||
cwd := config.Get().WorkingDir()
|
||||
t := styles.CurrentTheme()
|
||||
// Replace home directory with ~, unless we're at the top level of the
|
||||
// home directory).
|
||||
|
||||
@@ -31,8 +31,8 @@ const (
|
||||
|
||||
// ModelSelectedMsg is sent when a model is selected
|
||||
type ModelSelectedMsg struct {
|
||||
Model config.PreferredModel
|
||||
ModelType config.ModelType
|
||||
Model config.SelectedModel
|
||||
ModelType config.SelectedModelType
|
||||
}
|
||||
|
||||
// CloseModelDialogMsg is sent when a model is selected
|
||||
@@ -115,19 +115,19 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
items := m.modelList.Items()
|
||||
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
|
||||
|
||||
var modelType config.ModelType
|
||||
var modelType config.SelectedModelType
|
||||
if m.modelType == LargeModelType {
|
||||
modelType = config.LargeModel
|
||||
modelType = config.SelectedModelTypeLarge
|
||||
} else {
|
||||
modelType = config.SmallModel
|
||||
modelType = config.SelectedModelTypeSmall
|
||||
}
|
||||
|
||||
return m, tea.Sequence(
|
||||
util.CmdHandler(dialogs.CloseDialogMsg{}),
|
||||
util.CmdHandler(ModelSelectedMsg{
|
||||
Model: config.PreferredModel{
|
||||
ModelID: selectedItem.Model.ID,
|
||||
Provider: selectedItem.Provider.ID,
|
||||
Model: config.SelectedModel{
|
||||
Model: selectedItem.Model.ID,
|
||||
Provider: string(selectedItem.Provider.ID),
|
||||
},
|
||||
ModelType: modelType,
|
||||
}),
|
||||
@@ -218,35 +218,39 @@ func (m *modelDialogCmp) modelTypeRadio() string {
|
||||
func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
|
||||
m.modelType = modelType
|
||||
|
||||
providers := config.Providers()
|
||||
providers, err := config.Providers()
|
||||
if err != nil {
|
||||
return util.ReportError(err)
|
||||
}
|
||||
|
||||
modelItems := []util.Model{}
|
||||
selectIndex := 0
|
||||
|
||||
cfg := config.Get()
|
||||
var currentModel config.PreferredModel
|
||||
var currentModel config.SelectedModel
|
||||
if m.modelType == LargeModelType {
|
||||
currentModel = cfg.Models.Large
|
||||
currentModel = cfg.Models[config.SelectedModelTypeLarge]
|
||||
} else {
|
||||
currentModel = cfg.Models.Small
|
||||
currentModel = cfg.Models[config.SelectedModelTypeSmall]
|
||||
}
|
||||
|
||||
// Create a map to track which providers we've already added
|
||||
addedProviders := make(map[provider.InferenceProvider]bool)
|
||||
addedProviders := make(map[string]bool)
|
||||
|
||||
// First, add any configured providers that are not in the known providers list
|
||||
// These should appear at the top of the list
|
||||
knownProviders := provider.KnownProviders()
|
||||
for providerID, providerConfig := range cfg.Providers {
|
||||
if providerConfig.Disabled {
|
||||
if providerConfig.Disable {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this provider is not in the known providers list
|
||||
if !slices.Contains(knownProviders, providerID) {
|
||||
if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
|
||||
// Convert config provider to provider.Provider format
|
||||
configProvider := provider.Provider{
|
||||
Name: string(providerID), // Use provider ID as name for unknown providers
|
||||
ID: providerID,
|
||||
ID: provider.InferenceProvider(providerID),
|
||||
Models: make([]provider.Model, len(providerConfig.Models)),
|
||||
}
|
||||
|
||||
@@ -263,7 +267,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
|
||||
DefaultMaxTokens: model.DefaultMaxTokens,
|
||||
CanReason: model.CanReason,
|
||||
HasReasoningEffort: model.HasReasoningEffort,
|
||||
DefaultReasoningEffort: model.ReasoningEffort,
|
||||
DefaultReasoningEffort: model.DefaultReasoningEffort,
|
||||
SupportsImages: model.SupportsImages,
|
||||
}
|
||||
}
|
||||
@@ -279,7 +283,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
|
||||
Provider: configProvider,
|
||||
Model: model,
|
||||
}))
|
||||
if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
|
||||
if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
|
||||
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
|
||||
}
|
||||
}
|
||||
@@ -290,12 +294,12 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
|
||||
// Then add the known providers from the predefined list
|
||||
for _, provider := range providers {
|
||||
// Skip if we already added this provider as an unknown provider
|
||||
if addedProviders[provider.ID] {
|
||||
if addedProviders[string(provider.ID)] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this provider is configured and not disabled
|
||||
if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
|
||||
if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -309,7 +313,7 @@ func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}))
|
||||
if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
|
||||
if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
|
||||
selectIndex = len(modelItems) - 1 // Set the selected index to the current model
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +170,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
util.CmdHandler(ChatFocusedMsg{Focused: false}),
|
||||
)
|
||||
case key.Matches(msg, p.keyMap.AddAttachment):
|
||||
model := config.GetAgentModel(config.AgentCoder)
|
||||
agentCfg := config.Get().Agents["coder"]
|
||||
model := config.Get().GetModelByType(agentCfg.Model)
|
||||
if model.SupportsImages {
|
||||
return p, util.CmdHandler(OpenFilePickerMsg{})
|
||||
} else {
|
||||
|
||||
@@ -177,14 +177,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Update the agent with the new model/provider configuration
|
||||
if err := a.app.UpdateAgentModel(); err != nil {
|
||||
logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err))
|
||||
return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err))
|
||||
return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err))
|
||||
}
|
||||
|
||||
modelTypeName := "large"
|
||||
if msg.ModelType == config.SmallModel {
|
||||
if msg.ModelType == config.SelectedModelTypeSmall {
|
||||
modelTypeName = "small"
|
||||
}
|
||||
return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.ModelID))
|
||||
return a, util.ReportInfo(fmt.Sprintf("%s model changed to %s", modelTypeName, msg.Model.Model))
|
||||
|
||||
// File Picker
|
||||
case chat.OpenFilePickerMsg:
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
const (
|
||||
appName = "crush"
|
||||
defaultDataDirectory = ".crush"
|
||||
defaultLogLevel = "info"
|
||||
)
|
||||
|
||||
var defaultContextPaths = []string{
|
||||
".github/copilot-instructions.md",
|
||||
".cursorrules",
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"GEMINI.md",
|
||||
"gemini.md",
|
||||
"crush.md",
|
||||
"crush.local.md",
|
||||
"Crush.md",
|
||||
"Crush.local.md",
|
||||
"CRUSH.md",
|
||||
"CRUSH.local.md",
|
||||
}
|
||||
|
||||
type SelectedModelType string
|
||||
|
||||
const (
|
||||
SelectedModelTypeLarge SelectedModelType = "large"
|
||||
SelectedModelTypeSmall SelectedModelType = "small"
|
||||
)
|
||||
|
||||
type SelectedModel struct {
|
||||
// The model id as used by the provider API.
|
||||
// Required.
|
||||
Model string `json:"model"`
|
||||
// The model provider, same as the key/id used in the providers config.
|
||||
// Required.
|
||||
Provider string `json:"provider"`
|
||||
|
||||
// Only used by models that use the openai provider and need this set.
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
|
||||
// Overrides the default model configuration.
|
||||
MaxTokens int64 `json:"max_tokens,omitempty"`
|
||||
|
||||
// Used by anthropic models that can reason to indicate if the model should think.
|
||||
Think bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
// The provider's id.
|
||||
ID string `json:"id,omitempty"`
|
||||
// The provider's API endpoint.
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
|
||||
Type provider.Type `json:"type,omitempty"`
|
||||
// The provider's API key.
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
// Marks the provider as disabled.
|
||||
Disable bool `json:"disable,omitempty"`
|
||||
|
||||
// Extra headers to send with each request to the provider.
|
||||
ExtraHeaders map[string]string
|
||||
|
||||
// Used to pass extra parameters to the provider.
|
||||
ExtraParams map[string]string `json:"-"`
|
||||
|
||||
// The provider models
|
||||
Models []provider.Model `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
type MCPType string
|
||||
|
||||
const (
|
||||
MCPStdio MCPType = "stdio"
|
||||
MCPSse MCPType = "sse"
|
||||
MCPHttp MCPType = "http"
|
||||
)
|
||||
|
||||
type MCPConfig struct {
|
||||
Command string `json:"command,omitempty" `
|
||||
Env []string `json:"env,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Type MCPType `json:"type"`
|
||||
URL string `json:"url,omitempty"`
|
||||
|
||||
// TODO: maybe make it possible to get the value from the env
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
type LSPConfig struct {
|
||||
Disabled bool `json:"enabled,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Options any `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type TUIOptions struct {
|
||||
CompactMode bool `json:"compact_mode,omitempty"`
|
||||
// Here we can add themes later or any TUI related options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
ContextPaths []string `json:"context_paths,omitempty"`
|
||||
TUI *TUIOptions `json:"tui,omitempty"`
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
DebugLSP bool `json:"debug_lsp,omitempty"`
|
||||
DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
|
||||
// Relative to the cwd
|
||||
DataDirectory string `json:"data_directory,omitempty"`
|
||||
}
|
||||
|
||||
type MCPs map[string]MCPConfig
|
||||
|
||||
type MCP struct {
|
||||
Name string `json:"name"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
func (m MCPs) Sorted() []MCP {
|
||||
sorted := make([]MCP, 0, len(m))
|
||||
for k, v := range m {
|
||||
sorted = append(sorted, MCP{
|
||||
Name: k,
|
||||
MCP: v,
|
||||
})
|
||||
}
|
||||
slices.SortFunc(sorted, func(a, b MCP) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return sorted
|
||||
}
|
||||
|
||||
type LSPs map[string]LSPConfig
|
||||
|
||||
type LSP struct {
|
||||
Name string `json:"name"`
|
||||
LSP LSPConfig `json:"lsp"`
|
||||
}
|
||||
|
||||
func (l LSPs) Sorted() []LSP {
|
||||
sorted := make([]LSP, 0, len(l))
|
||||
for k, v := range l {
|
||||
sorted = append(sorted, LSP{
|
||||
Name: k,
|
||||
LSP: v,
|
||||
})
|
||||
}
|
||||
slices.SortFunc(sorted, func(a, b LSP) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
return sorted
|
||||
}
|
||||
|
||||
// Config holds the configuration for crush.
|
||||
type Config struct {
|
||||
// We currently only support large/small as values here.
|
||||
Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
|
||||
|
||||
// The providers that are configured
|
||||
Providers map[string]ProviderConfig `json:"providers,omitempty"`
|
||||
|
||||
MCP MCPs `json:"mcp,omitempty"`
|
||||
|
||||
LSP LSPs `json:"lsp,omitempty"`
|
||||
|
||||
Options *Options `json:"options,omitempty"`
|
||||
|
||||
// Internal
|
||||
workingDir string `json:"-"`
|
||||
}
|
||||
|
||||
func (c *Config) WorkingDir() string {
|
||||
return c.workingDir
|
||||
}
|
||||
|
||||
func (c *Config) EnabledProviders() []ProviderConfig {
|
||||
enabled := make([]ProviderConfig, 0, len(c.Providers))
|
||||
for _, p := range c.Providers {
|
||||
if !p.Disable {
|
||||
enabled = append(enabled, p)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// IsConfigured return true if at least one provider is configured
|
||||
func (c *Config) IsConfigured() bool {
|
||||
return len(c.EnabledProviders()) > 0
|
||||
}
|
||||
|
||||
func (c *Config) GetModel(provider, model string) *provider.Model {
|
||||
if providerConfig, ok := c.Providers[provider]; ok {
|
||||
for _, m := range providerConfig.Models {
|
||||
if m.ID == model {
|
||||
return &m
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) LargeModel() *provider.Model {
|
||||
model, ok := c.Models[SelectedModelTypeLarge]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
|
||||
func (c *Config) SmallModel() *provider.Model {
|
||||
model, ok := c.Models[SelectedModelTypeSmall]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
type ProviderClient interface {
|
||||
GetProviders() ([]provider.Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
providerOnce sync.Once
|
||||
providerList []provider.Provider
|
||||
)
|
||||
|
||||
// file to cache provider data
|
||||
func providerCacheFileData() string {
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
return filepath.Join(xdgDataHome, appName)
|
||||
}
|
||||
|
||||
// return the path to the main data directory
|
||||
// for windows, it should be in `%LOCALAPPDATA%/crush/`
|
||||
// for linux and macOS, it should be in `$HOME/.local/share/crush/`
|
||||
if runtime.GOOS == "windows" {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local")
|
||||
}
|
||||
return filepath.Join(localAppData, appName)
|
||||
}
|
||||
|
||||
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
|
||||
}
|
||||
|
||||
func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(providers, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var providers []provider.Provider
|
||||
err = json.Unmarshal(data, &providers)
|
||||
return providers, err
|
||||
}
|
||||
|
||||
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
|
||||
providers, err := client.GetProviders()
|
||||
if err != nil {
|
||||
fallbackToCache, err := loadProvidersFromCache(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providers = fallbackToCache
|
||||
} else {
|
||||
if err := saveProvidersInCache(path, providerList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
|
||||
var err error
|
||||
providerOnce.Do(func() {
|
||||
providerList, err = loadProviders(providerCacheFileData(), client)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providerList, nil
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockProviderClient struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
|
||||
if m.shouldFail {
|
||||
return nil, errors.New("failed to load providers")
|
||||
}
|
||||
return []provider.Provider{
|
||||
{
|
||||
Name: "Mock",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersNoIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: false}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, providers)
|
||||
assert.Len(t, providers, 1)
|
||||
|
||||
// check if file got saved
|
||||
fileInfo, err := os.Stat(tmpPath)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersWithIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
// store providers to a temporary file
|
||||
oldProviders := []provider.Provider{
|
||||
{
|
||||
Name: "OldProvider",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(oldProviders)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal old providers: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(tmpPath, data, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write old providers to file: %v", err)
|
||||
}
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, providers)
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
providers, err := loadProviders(tmpPath, client)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
|
||||
}
|
||||
Reference in New Issue
Block a user