Merge pull request #292 from charmbracelet/catwalk

feat: use new catwalk
This commit is contained in:
Kujtim Hoxha
2025-07-24 09:40:48 +02:00
committed by GitHub
26 changed files with 200 additions and 337 deletions

View File

@@ -9,8 +9,8 @@ import (
"strings"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/tidwall/sjson"
"golang.org/x/exp/slog"
)
@@ -70,7 +70,7 @@ type ProviderConfig struct {
// The provider's API endpoint.
BaseURL string `json:"base_url,omitempty"`
// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
Type provider.Type `json:"type,omitempty"`
Type catwalk.Type `json:"type,omitempty"`
// The provider's API key.
APIKey string `json:"api_key,omitempty"`
// Marks the provider as disabled.
@@ -85,7 +85,7 @@ type ProviderConfig struct {
ExtraParams map[string]string `json:"-"`
// The provider models
Models []provider.Model `json:"models,omitempty"`
Models []catwalk.Model `json:"models,omitempty"`
}
type MCPType string
@@ -250,8 +250,8 @@ type Config struct {
Agents map[string]Agent `json:"-"`
// TODO: find a better way to do this this should probably not be part of the config
resolver VariableResolver
dataConfigDir string `json:"-"`
knownProviders []provider.Provider `json:"-"`
dataConfigDir string `json:"-"`
knownProviders []catwalk.Provider `json:"-"`
}
func (c *Config) WorkingDir() string {
@@ -273,7 +273,7 @@ func (c *Config) IsConfigured() bool {
return len(c.EnabledProviders()) > 0
}
func (c *Config) GetModel(provider, model string) *provider.Model {
func (c *Config) GetModel(provider, model string) *catwalk.Model {
if providerConfig, ok := c.Providers[provider]; ok {
for _, m := range providerConfig.Models {
if m.ID == model {
@@ -295,7 +295,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi
return nil
}
func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
model, ok := c.Models[modelType]
if !ok {
return nil
@@ -303,7 +303,7 @@ func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
return c.GetModel(model.Provider, model.Model)
}
func (c *Config) LargeModel() *provider.Model {
func (c *Config) LargeModel() *catwalk.Model {
model, ok := c.Models[SelectedModelTypeLarge]
if !ok {
return nil
@@ -311,7 +311,7 @@ func (c *Config) LargeModel() *provider.Model {
return c.GetModel(model.Provider, model.Model)
}
func (c *Config) SmallModel() *provider.Model {
func (c *Config) SmallModel() *catwalk.Model {
model, ok := c.Models[SelectedModelTypeSmall]
if !ok {
return nil
@@ -381,7 +381,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
return nil
}
var foundProvider *provider.Provider
var foundProvider *catwalk.Provider
for _, p := range c.knownProviders {
if string(p.ID) == providerID {
foundProvider = &p
@@ -450,14 +450,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
headers := make(map[string]string)
apiKey, _ := resolver.ResolveValue(c.APIKey)
switch c.Type {
case provider.TypeOpenAI:
case catwalk.TypeOpenAI:
baseURL, _ := resolver.ResolveValue(c.BaseURL)
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
testURL = baseURL + "/models"
headers["Authorization"] = "Bearer " + apiKey
case provider.TypeAnthropic:
case catwalk.TypeAnthropic:
baseURL, _ := resolver.ResolveValue(c.BaseURL)
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"

View File

@@ -11,13 +11,14 @@ import (
"strings"
"sync"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/log"
"golang.org/x/exp/slog"
)
const catwalkURL = "https://catwalk.charm.sh"
// LoadReader config via io.Reader.
func LoadReader(fd io.Reader) (*Config, error) {
data, err := io.ReadAll(fd)
@@ -61,8 +62,8 @@ func Load(workingDir string, debug bool) (*Config, error) {
cfg.Options.Debug,
)
// Load known providers, this loads the config from fur
providers, err := LoadProviders(client.New())
// Load known providers, this loads the config from catwalk
providers, err := LoadProviders(catwalk.NewWithURL(catwalkURL))
if err != nil || len(providers) == 0 {
return nil, fmt.Errorf("failed to load providers: %w", err)
}
@@ -81,7 +82,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
var wg sync.WaitGroup
for _, p := range cfg.Providers {
if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic {
wg.Add(1)
go func(provider ProviderConfig) {
defer wg.Done()
@@ -117,7 +118,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
return cfg, nil
}
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
knownProviderNames := make(map[string]bool)
for _, p := range knownProviders {
knownProviderNames[string(p.ID)] = true
@@ -136,7 +137,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
p.APIKey = config.APIKey
}
if len(config.Models) > 0 {
models := []provider.Model{}
models := []catwalk.Model{}
seen := make(map[string]bool)
for _, model := range config.Models {
@@ -144,8 +145,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
continue
}
seen[model.ID] = true
if model.Model == "" {
model.Model = model.ID
if model.Name == "" {
model.Name = model.ID
}
models = append(models, model)
}
@@ -154,8 +155,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
continue
}
seen[model.ID] = true
if model.Model == "" {
model.Model = model.ID
if model.Name == "" {
model.Name = model.ID
}
models = append(models, model)
}
@@ -178,7 +179,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
switch p.ID {
// Handle specific providers that require additional configuration
case provider.InferenceProviderVertexAI:
case catwalk.InferenceProviderVertexAI:
if !hasVertexCredentials(env) {
if configExists {
slog.Warn("Skipping Vertex AI provider due to missing credentials")
@@ -188,7 +189,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
case provider.InferenceProviderAzure:
case catwalk.InferenceProviderAzure:
endpoint, err := resolver.ResolveValue(p.APIEndpoint)
if err != nil || endpoint == "" {
if configExists {
@@ -199,7 +200,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
prepared.BaseURL = endpoint
prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
case provider.InferenceProviderBedrock:
case catwalk.InferenceProviderBedrock:
if !hasAWSCredentials(env) {
if configExists {
slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
@@ -239,7 +240,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
// default to OpenAI if not set
if providerConfig.Type == "" {
providerConfig.Type = provider.TypeOpenAI
providerConfig.Type = catwalk.TypeOpenAI
}
if providerConfig.Disable {
@@ -260,7 +261,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
delete(c.Providers, id)
continue
}
if providerConfig.Type != provider.TypeOpenAI {
if providerConfig.Type != catwalk.TypeOpenAI {
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
delete(c.Providers, id)
continue
@@ -315,7 +316,7 @@ func (c *Config) setDefaults(workingDir string) {
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
}
func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
if len(knownProviders) == 0 && len(c.Providers) == 0 {
err = fmt.Errorf("no providers configured, please configure at least one provider")
return
@@ -384,7 +385,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg
return
}
func (c *Config) configureSelectedModels(knownProviders []provider.Provider) error {
func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
if err != nil {
return fmt.Errorf("failed to select default models: %w", err)

View File

@@ -8,8 +8,8 @@ import (
"strings"
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/stretchr/testify/assert"
)
@@ -54,12 +54,12 @@ func TestConfig_setDefaults(t *testing.T) {
}
func TestConfig_configureProviders(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -80,12 +80,12 @@ func TestConfig_configureProviders(t *testing.T) {
}
func TestConfig_configureProvidersWithOverride(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -96,10 +96,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
"openai": {
APIKey: "xyz",
BaseURL: "https://api.openai.com/v2",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "test-model",
Model: "Updated",
ID: "test-model",
Name: "Updated",
},
{
ID: "another-model",
@@ -122,16 +122,16 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
assert.Len(t, cfg.Providers["openai"].Models, 2)
assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model)
assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name)
}
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -142,7 +142,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
"custom": {
APIKey: "xyz",
BaseURL: "https://api.someendpoint.com/v2",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "test-model",
},
@@ -172,12 +172,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
}
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderBedrock,
ID: catwalk.InferenceProviderBedrock,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
}},
},
@@ -201,12 +201,12 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
}
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderBedrock,
ID: catwalk.InferenceProviderBedrock,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
}},
},
@@ -223,12 +223,12 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
}
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderBedrock,
ID: catwalk.InferenceProviderBedrock,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "some-random-model",
}},
},
@@ -246,12 +246,12 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
}
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderVertexAI,
ID: catwalk.InferenceProviderVertexAI,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "gemini-pro",
}},
},
@@ -278,12 +278,12 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
}
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderVertexAI,
ID: catwalk.InferenceProviderVertexAI,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "gemini-pro",
}},
},
@@ -304,12 +304,12 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
}
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderVertexAI,
ID: catwalk.InferenceProviderVertexAI,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "gemini-pro",
}},
},
@@ -329,12 +329,12 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
}
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -450,12 +450,12 @@ func TestConfig_IsConfigured(t *testing.T) {
}
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -489,7 +489,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
Providers: map[string]ProviderConfig{
"custom": {
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -502,7 +502,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 1)
@@ -515,7 +515,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
Providers: map[string]ProviderConfig{
"custom": {
APIKey: "test-key",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -525,7 +525,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 0)
@@ -539,7 +539,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
Models: []catwalk.Model{},
},
},
}
@@ -547,7 +547,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 0)
@@ -562,7 +562,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Type: "unsupported",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -572,7 +572,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 0)
@@ -586,8 +586,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Type: provider.TypeOpenAI,
Models: []provider.Model{{
Type: catwalk.TypeOpenAI,
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -597,7 +597,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 1)
@@ -614,9 +614,9 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Type: provider.TypeOpenAI,
Type: catwalk.TypeOpenAI,
Disable: true,
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -626,7 +626,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []provider.Provider{})
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
assert.Len(t, cfg.Providers, 0)
@@ -637,12 +637,12 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderVertexAI,
ID: catwalk.InferenceProviderVertexAI,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "gemini-pro",
}},
},
@@ -670,12 +670,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
})
t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: provider.InferenceProviderBedrock,
ID: catwalk.InferenceProviderBedrock,
APIKey: "",
APIEndpoint: "",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
}},
},
@@ -701,12 +701,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
})
t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$MISSING_API_KEY",
APIEndpoint: "https://api.openai.com/v1",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -732,12 +732,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
})
t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$OPENAI_API_KEY",
APIEndpoint: "$MISSING_ENDPOINT",
Models: []provider.Model{{
Models: []catwalk.Model{{
ID: "test-model",
}},
},
@@ -767,13 +767,13 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
func TestConfig_defaultModelSelection(t *testing.T) {
t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "abc",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,
@@ -803,13 +803,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
assert.Equal(t, int64(500), small.MaxTokens)
})
t.Run("should error if no providers configured", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$MISSING_KEY",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,
@@ -833,13 +833,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
assert.Error(t, err)
})
t.Run("should error if model is missing", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "abc",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "not-large-model",
DefaultMaxTokens: 1000,
@@ -863,13 +863,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
})
t.Run("should configure the default models with a custom provider", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$MISSING", // will not be included in the config
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "not-large-model",
DefaultMaxTokens: 1000,
@@ -887,7 +887,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "model",
DefaultMaxTokens: 600,
@@ -912,13 +912,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
})
t.Run("should fail if no model configured", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "$MISSING", // will not be included in the config
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "not-large-model",
DefaultMaxTokens: 1000,
@@ -936,7 +936,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{},
Models: []catwalk.Model{},
},
},
}
@@ -949,13 +949,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
assert.Error(t, err)
})
t.Run("should use the default provider first", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "set",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,
@@ -973,7 +973,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
"custom": {
APIKey: "test-key",
BaseURL: "https://api.custom.com/v1",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,
@@ -1000,13 +1000,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
func TestConfig_configureSelectedModels(t *testing.T) {
t.Run("should override defaults", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "abc",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "larger-model",
DefaultMaxTokens: 2000,
@@ -1048,13 +1048,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
assert.Equal(t, int64(500), small.MaxTokens)
})
t.Run("should be possible to use multiple providers", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "abc",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,
@@ -1070,7 +1070,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
APIKey: "abc",
DefaultLargeModelID: "a-large-model",
DefaultSmallModelID: "a-small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "a-large-model",
DefaultMaxTokens: 1000,
@@ -1111,13 +1111,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
})
t.Run("should override the max tokens only", func(t *testing.T) {
knownProviders := []provider.Provider{
knownProviders := []catwalk.Provider{
{
ID: "openai",
APIKey: "abc",
DefaultLargeModelID: "large-model",
DefaultSmallModelID: "small-model",
Models: []provider.Model{
Models: []catwalk.Model{
{
ID: "large-model",
DefaultMaxTokens: 1000,

View File

@@ -7,17 +7,16 @@ import (
"runtime"
"sync"
"github.com/charmbracelet/crush/internal/fur/client"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/catwalk/pkg/catwalk"
)
type ProviderClient interface {
GetProviders() ([]provider.Provider, error)
GetProviders() ([]catwalk.Provider, error)
}
var (
providerOnce sync.Once
providerList []provider.Provider
providerList []catwalk.Provider
)
// file to cache provider data
@@ -41,7 +40,7 @@ func providerCacheFileData() string {
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
}
func saveProvidersInCache(path string, providers []provider.Provider) error {
func saveProvidersInCache(path string, providers []catwalk.Provider) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
@@ -55,18 +54,18 @@ func saveProvidersInCache(path string, providers []provider.Provider) error {
return os.WriteFile(path, data, 0o644)
}
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var providers []provider.Provider
var providers []catwalk.Provider
err = json.Unmarshal(data, &providers)
return providers, err
}
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
func loadProviders(path string, client ProviderClient) ([]catwalk.Provider, error) {
providers, err := client.GetProviders()
if err != nil {
fallbackToCache, err := loadProvidersFromCache(path)
@@ -82,11 +81,11 @@ func loadProviders(path string, client ProviderClient) ([]provider.Provider, err
return providers, nil
}
func Providers() ([]provider.Provider, error) {
return LoadProviders(client.New())
func Providers() ([]catwalk.Provider, error) {
return LoadProviders(catwalk.NewWithURL(catwalkURL))
}
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
func LoadProviders(client ProviderClient) ([]catwalk.Provider, error) {
var err error
providerOnce.Do(func() {
providerList, err = loadProviders(providerCacheFileData(), client)

View File

@@ -6,7 +6,7 @@ import (
"os"
"testing"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/stretchr/testify/assert"
)
@@ -14,11 +14,11 @@ type mockProviderClient struct {
shouldFail bool
}
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
if m.shouldFail {
return nil, errors.New("failed to load providers")
}
return []provider.Provider{
return []catwalk.Provider{
{
Name: "Mock",
},
@@ -43,7 +43,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
// store providers to a temporary file
oldProviders := []provider.Provider{
oldProviders := []catwalk.Provider{
{
Name: "OldProvider",
},

View File

@@ -1,63 +0,0 @@
// Package client provides a client for interacting with the fur service.
package client
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/charmbracelet/crush/internal/fur/provider"
)
const defaultURL = "https://fur.charm.sh"
// Client represents a client for the fur service.
type Client struct {
baseURL string
httpClient *http.Client
}
// New creates a new client instance
// Uses FUR_URL environment variable or falls back to localhost:8080.
func New() *Client {
baseURL := os.Getenv("FUR_URL")
if baseURL == "" {
baseURL = defaultURL
}
return &Client{
baseURL: baseURL,
httpClient: &http.Client{},
}
}
// NewWithURL creates a new client with a specific URL.
func NewWithURL(url string) *Client {
return &Client{
baseURL: url,
httpClient: &http.Client{},
}
}
// GetProviders retrieves all available providers from the service.
func (c *Client) GetProviders() ([]provider.Provider, error) {
url := fmt.Sprintf("%s/providers", c.baseURL)
resp, err := c.httpClient.Get(url) //nolint:noctx
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var providers []provider.Provider
if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return providers, nil
}

View File

@@ -1,75 +0,0 @@
// Package provider provides types and constants for AI providers.
package provider
// Type represents the type of AI provider.
type Type string
// All the supported AI provider types.
const (
TypeOpenAI Type = "openai"
TypeAnthropic Type = "anthropic"
TypeGemini Type = "gemini"
TypeAzure Type = "azure"
TypeBedrock Type = "bedrock"
TypeVertexAI Type = "vertexai"
TypeXAI Type = "xai"
)
// InferenceProvider represents the inference provider identifier.
type InferenceProvider string
// All the inference providers supported by the system.
const (
InferenceProviderOpenAI InferenceProvider = "openai"
InferenceProviderAnthropic InferenceProvider = "anthropic"
InferenceProviderGemini InferenceProvider = "gemini"
InferenceProviderAzure InferenceProvider = "azure"
InferenceProviderBedrock InferenceProvider = "bedrock"
InferenceProviderVertexAI InferenceProvider = "vertexai"
InferenceProviderXAI InferenceProvider = "xai"
InferenceProviderGROQ InferenceProvider = "groq"
InferenceProviderOpenRouter InferenceProvider = "openrouter"
)
// Provider represents an AI provider configuration.
type Provider struct {
Name string `json:"name"`
ID InferenceProvider `json:"id"`
APIKey string `json:"api_key,omitempty"`
APIEndpoint string `json:"api_endpoint,omitempty"`
Type Type `json:"type,omitempty"`
DefaultLargeModelID string `json:"default_large_model_id,omitempty"`
DefaultSmallModelID string `json:"default_small_model_id,omitempty"`
Models []Model `json:"models,omitempty"`
}
// Model represents an AI model configuration.
type Model struct {
ID string `json:"id"`
Model string `json:"model"`
CostPer1MIn float64 `json:"cost_per_1m_in"`
CostPer1MOut float64 `json:"cost_per_1m_out"`
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
ContextWindow int64 `json:"context_window"`
DefaultMaxTokens int64 `json:"default_max_tokens"`
CanReason bool `json:"can_reason"`
HasReasoningEffort bool `json:"has_reasoning_efforts"`
DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"`
SupportsImages bool `json:"supports_attachments"`
}
// KnownProviders returns all the known inference providers.
func KnownProviders() []InferenceProvider {
return []InferenceProvider{
InferenceProviderOpenAI,
InferenceProviderAnthropic,
InferenceProviderGemini,
InferenceProviderAzure,
InferenceProviderBedrock,
InferenceProviderVertexAI,
InferenceProviderXAI,
InferenceProviderGROQ,
InferenceProviderOpenRouter,
}
}

View File

@@ -10,8 +10,8 @@ import (
"sync"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
fur "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
@@ -52,7 +52,7 @@ type AgentEvent struct {
type Service interface {
pubsub.Suscriber[AgentEvent]
Model() fur.Model
Model() catwalk.Model
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
Cancel(sessionID string)
CancelAll()
@@ -219,7 +219,7 @@ func NewAgent(
return agent, nil
}
func (a *agent) Model() fur.Model {
func (a *agent) Model() catwalk.Model {
return *config.Get().GetModelByType(a.agentCfg.Model)
}
@@ -638,7 +638,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
return nil
}
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
sess, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)

View File

@@ -9,17 +9,17 @@ import (
"runtime"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
)
func CoderPrompt(p string, contextFiles ...string) string {
var basePrompt string
switch p {
case string(provider.InferenceProviderOpenAI):
case string(catwalk.InferenceProviderOpenAI):
basePrompt = baseOpenAICoderPrompt
case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
case string(catwalk.InferenceProviderGemini), string(catwalk.InferenceProviderVertexAI):
basePrompt = baseGeminiCoderPrompt
default:
basePrompt = baseAnthropicCoderPrompt

View File

@@ -15,8 +15,8 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -71,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
var contentBlocks []anthropic.ContentBlockParamUnion
contentBlocks = append(contentBlocks, content)
for _, binaryContent := range msg.BinaryContent() {
base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
contentBlocks = append(contentBlocks, imageBlock)
}
@@ -529,6 +529,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
}
}
func (a *anthropicClient) Model() provider.Model {
func (a *anthropicClient) Model() catwalk.Model {
return a.providerOptions.model(a.providerOptions.modelType)
}

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"strings"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -32,7 +32,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
}
}
opts.model = func(modelType config.SelectedModelType) provider.Model {
opts.model = func(modelType config.SelectedModelType) catwalk.Model {
model := config.Get().GetModelByType(modelType)
// Prefix the model name with region
@@ -88,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
return b.childProvider.stream(ctx, messages, tools)
}
func (b *bedrockClient) Model() provider.Model {
func (b *bedrockClient) Model() catwalk.Model {
return b.providerOptions.model(b.providerOptions.modelType)
}

View File

@@ -10,8 +10,8 @@ import (
"strings"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
"github.com/google/uuid"
@@ -463,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
}
}
func (g *geminiClient) Model() provider.Model {
func (g *geminiClient) Model() catwalk.Model {
return g.providerOptions.model(g.providerOptions.modelType)
}

View File

@@ -9,8 +9,8 @@ import (
"log/slog"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
"github.com/openai/openai-go"
@@ -66,7 +66,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
@@ -486,6 +486,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
}
}
func (o *openaiClient) Model() provider.Model {
func (o *openaiClient) Model() catwalk.Model {
return o.providerOptions.model(o.providerOptions.modelType)
}

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
"github.com/openai/openai-go"
@@ -55,10 +55,10 @@ func TestOpenAIClientStreamChoices(t *testing.T) {
modelType: config.SelectedModelTypeLarge,
apiKey: "test-key",
systemMessage: "test",
model: func(config.SelectedModelType) provider.Model {
return provider.Model{
ID: "test-model",
Model: "test-model",
model: func(config.SelectedModelType) catwalk.Model {
return catwalk.Model{
ID: "test-model",
Name: "test-model",
}
},
},

View File

@@ -4,8 +4,8 @@ import (
"context"
"fmt"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/message"
)
@@ -57,7 +57,7 @@ type Provider interface {
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() provider.Model
Model() catwalk.Model
}
type providerClientOptions struct {
@@ -65,7 +65,7 @@ type providerClientOptions struct {
config config.ProviderConfig
apiKey string
modelType config.SelectedModelType
model func(config.SelectedModelType) provider.Model
model func(config.SelectedModelType) catwalk.Model
disableCache bool
systemMessage string
maxTokens int64
@@ -80,7 +80,7 @@ type ProviderClient interface {
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
Model() provider.Model
Model() catwalk.Model
}
type baseProvider[C ProviderClient] struct {
@@ -109,7 +109,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
return p.client.stream(ctx, messages, tools)
}
func (p *baseProvider[C]) Model() provider.Model {
func (p *baseProvider[C]) Model() catwalk.Model {
return p.client.Model()
}
@@ -149,7 +149,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
apiKey: resolvedAPIKey,
extraHeaders: cfg.ExtraHeaders,
extraBody: cfg.ExtraBody,
model: func(tp config.SelectedModelType) provider.Model {
model: func(tp config.SelectedModelType) catwalk.Model {
return *config.Get().GetModelByType(tp)
},
}
@@ -157,37 +157,37 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
o(&clientOptions)
}
switch cfg.Type {
case provider.TypeAnthropic:
case catwalk.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
client: newAnthropicClient(clientOptions, false),
}, nil
case provider.TypeOpenAI:
case catwalk.TypeOpenAI:
return &baseProvider[OpenAIClient]{
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
case provider.TypeGemini:
case catwalk.TypeGemini:
return &baseProvider[GeminiClient]{
options: clientOptions,
client: newGeminiClient(clientOptions),
}, nil
case provider.TypeBedrock:
case catwalk.TypeBedrock:
return &baseProvider[BedrockClient]{
options: clientOptions,
client: newBedrockClient(clientOptions),
}, nil
case provider.TypeAzure:
case catwalk.TypeAzure:
return &baseProvider[AzureClient]{
options: clientOptions,
client: newAzureClient(clientOptions),
}, nil
case provider.TypeVertexAI:
case catwalk.TypeVertexAI:
return &baseProvider[VertexAIClient]{
options: clientOptions,
client: newVertexAIClient(clientOptions),
}, nil
case provider.TypeXAI:
case catwalk.TypeXAI:
clientOptions.baseURL = "https://api.x.ai/v1"
return &baseProvider[OpenAIClient]{
options: clientOptions,

View File

@@ -5,7 +5,7 @@ import (
"slices"
"time"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/catwalk/pkg/catwalk"
)
type MessageRole string
@@ -74,9 +74,9 @@ type BinaryContent struct {
Data []byte
}
func (bc BinaryContent) String(p provider.InferenceProvider) string {
func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
if p == provider.InferenceProviderOpenAI {
if p == catwalk.InferenceProviderOpenAI {
return "data:" + bc.MIMEType + ";base64," + base64Encoded
}
return base64Encoded

View File

@@ -8,11 +8,11 @@ import (
"github.com/charmbracelet/bubbles/v2/viewport"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/x/ansi"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/tui/components/anim"
"github.com/charmbracelet/crush/internal/tui/components/core"
@@ -369,11 +369,11 @@ func (m *assistantSectionModel) View() string {
model := config.Get().GetModel(m.message.Provider, m.message.Model)
if model == nil {
// This means the model is not configured anymore
model = &provider.Model{
Model: "Unknown Model",
model = &catwalk.Model{
Name: "Unknown Model",
}
}
modelFormatted := t.S().Muted.Render(model.Model)
modelFormatted := t.S().Muted.Render(model.Name)
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
return t.S().Base.PaddingLeft(2).Render(
core.Section(assistant, m.width-2),

View File

@@ -9,10 +9,10 @@ import (
"sync"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/diff"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/lsp/protocol"
@@ -897,7 +897,7 @@ func (s *sidebarCmp) currentModelBlock() string {
t := styles.CurrentTheme()
modelIcon := t.S().Base.Foreground(t.FgSubtle).Render(styles.ModelIcon)
modelName := t.S().Text.Render(model.Model)
modelName := t.S().Text.Render(model.Name)
modelInfo := fmt.Sprintf("%s %s", modelIcon, modelName)
parts := []string{
modelInfo,
@@ -905,14 +905,14 @@ func (s *sidebarCmp) currentModelBlock() string {
if model.CanReason {
reasoningInfoStyle := t.S().Subtle.PaddingLeft(2)
switch modelProvider.Type {
case provider.TypeOpenAI:
case catwalk.TypeOpenAI:
reasoningEffort := model.DefaultReasoningEffort
if selectedModel.ReasoningEffort != "" {
reasoningEffort = selectedModel.ReasoningEffort
}
formatter := cases.Title(language.English, cases.NoLower)
parts = append(parts, reasoningInfoStyle.Render(formatter.String(fmt.Sprintf("Reasoning %s", reasoningEffort))))
case provider.TypeAnthropic:
case catwalk.TypeAnthropic:
formatter := cases.Title(language.English, cases.NoLower)
if selectedModel.Think {
parts = append(parts, reasoningInfoStyle.Render(formatter.String("Thinking on")))

View File

@@ -10,8 +10,8 @@ import (
"github.com/charmbracelet/bubbles/v2/key"
"github.com/charmbracelet/bubbles/v2/spinner"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/completions"
@@ -109,7 +109,7 @@ func (s *splashCmp) SetOnboarding(onboarding bool) {
if err != nil {
return
}
filteredProviders := []provider.Provider{}
filteredProviders := []catwalk.Provider{}
simpleProviders := []string{
"anthropic",
"openai",
@@ -407,7 +407,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
return nil
}
func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
providers, err := config.Providers()
if err != nil {
return nil, err

View File

@@ -4,10 +4,10 @@ import (
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/lipgloss/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/tui/components/chat"
"github.com/charmbracelet/crush/internal/tui/components/completions"
@@ -270,7 +270,7 @@ func (c *commandDialogCmp) defaultCommands() []Command {
providerCfg := cfg.GetProviderForModel(agentCfg.Model)
model := cfg.GetModelByType(agentCfg.Model)
if providerCfg != nil && model != nil &&
providerCfg.Type == provider.TypeAnthropic && model.CanReason {
providerCfg.Type == catwalk.TypeAnthropic && model.CanReason {
selectedModel := cfg.Models[agentCfg.Model]
status := "Enable"
if selectedModel.Think {

View File

@@ -5,8 +5,8 @@ import (
"slices"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
@@ -18,7 +18,7 @@ import (
type ModelListComponent struct {
list list.ListModel
modelType int
providers []provider.Provider
providers []catwalk.Provider
}
func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
@@ -109,19 +109,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
// Check if this provider is not in the known providers list
if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) {
if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
// Convert config provider to provider.Provider format
configProvider := provider.Provider{
configProvider := catwalk.Provider{
Name: providerConfig.Name,
ID: provider.InferenceProvider(providerID),
Models: make([]provider.Model, len(providerConfig.Models)),
ID: catwalk.InferenceProvider(providerID),
Models: make([]catwalk.Model, len(providerConfig.Models)),
}
// Convert models
for i, model := range providerConfig.Models {
configProvider.Models[i] = provider.Model{
configProvider.Models[i] = catwalk.Model{
ID: model.ID,
Model: model.Model,
Name: model.Name,
CostPer1MIn: model.CostPer1MIn,
CostPer1MOut: model.CostPer1MOut,
CostPer1MInCached: model.CostPer1MInCached,
@@ -144,7 +144,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
section.SetInfo(configured)
modelItems = append(modelItems, section)
for _, model := range configProvider.Models {
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
Provider: configProvider,
Model: model,
}))
@@ -179,7 +179,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
}
modelItems = append(modelItems, section)
for _, model := range provider.Models {
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
Provider: provider,
Model: model,
}))
@@ -201,6 +201,6 @@ func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
m.list.SetFilterPlaceholder(placeholder)
}
func (m *ModelListComponent) SetProviders(providers []provider.Provider) {
func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
m.providers = providers
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/charmbracelet/bubbles/v2/key"
"github.com/charmbracelet/bubbles/v2/spinner"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
@@ -48,8 +48,8 @@ type ModelDialog interface {
}
type ModelOption struct {
Provider provider.Provider
Model provider.Model
Provider catwalk.Provider
Model catwalk.Model
}
type modelDialogCmp struct {
@@ -363,7 +363,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
return false
}
func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
providers, err := config.Providers()
if err != nil {
return nil, err

View File

@@ -279,7 +279,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if model.SupportsImages {
return p, util.CmdHandler(OpenFilePickerMsg{})
} else {
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Model)
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
}
case key.Matches(msg, p.keyMap.Tab):
if p.session.ID == "" {