mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Merge pull request #292 from charmbracelet/catwalk
feat: use new catwalk
This commit is contained in:
@@ -9,8 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
@@ -70,7 +70,7 @@ type ProviderConfig struct {
|
||||
// The provider's API endpoint.
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
|
||||
Type provider.Type `json:"type,omitempty"`
|
||||
Type catwalk.Type `json:"type,omitempty"`
|
||||
// The provider's API key.
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
// Marks the provider as disabled.
|
||||
@@ -85,7 +85,7 @@ type ProviderConfig struct {
|
||||
ExtraParams map[string]string `json:"-"`
|
||||
|
||||
// The provider models
|
||||
Models []provider.Model `json:"models,omitempty"`
|
||||
Models []catwalk.Model `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
type MCPType string
|
||||
@@ -250,8 +250,8 @@ type Config struct {
|
||||
Agents map[string]Agent `json:"-"`
|
||||
// TODO: find a better way to do this this should probably not be part of the config
|
||||
resolver VariableResolver
|
||||
dataConfigDir string `json:"-"`
|
||||
knownProviders []provider.Provider `json:"-"`
|
||||
dataConfigDir string `json:"-"`
|
||||
knownProviders []catwalk.Provider `json:"-"`
|
||||
}
|
||||
|
||||
func (c *Config) WorkingDir() string {
|
||||
@@ -273,7 +273,7 @@ func (c *Config) IsConfigured() bool {
|
||||
return len(c.EnabledProviders()) > 0
|
||||
}
|
||||
|
||||
func (c *Config) GetModel(provider, model string) *provider.Model {
|
||||
func (c *Config) GetModel(provider, model string) *catwalk.Model {
|
||||
if providerConfig, ok := c.Providers[provider]; ok {
|
||||
for _, m := range providerConfig.Models {
|
||||
if m.ID == model {
|
||||
@@ -295,7 +295,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
|
||||
func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
|
||||
model, ok := c.Models[modelType]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -303,7 +303,7 @@ func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
|
||||
func (c *Config) LargeModel() *provider.Model {
|
||||
func (c *Config) LargeModel() *catwalk.Model {
|
||||
model, ok := c.Models[SelectedModelTypeLarge]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -311,7 +311,7 @@ func (c *Config) LargeModel() *provider.Model {
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
|
||||
func (c *Config) SmallModel() *provider.Model {
|
||||
func (c *Config) SmallModel() *catwalk.Model {
|
||||
model, ok := c.Models[SelectedModelTypeSmall]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -381,7 +381,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var foundProvider *provider.Provider
|
||||
var foundProvider *catwalk.Provider
|
||||
for _, p := range c.knownProviders {
|
||||
if string(p.ID) == providerID {
|
||||
foundProvider = &p
|
||||
@@ -450,14 +450,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
|
||||
headers := make(map[string]string)
|
||||
apiKey, _ := resolver.ResolveValue(c.APIKey)
|
||||
switch c.Type {
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
baseURL, _ := resolver.ResolveValue(c.BaseURL)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
testURL = baseURL + "/models"
|
||||
headers["Authorization"] = "Bearer " + apiKey
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
baseURL, _ := resolver.ResolveValue(c.BaseURL)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com/v1"
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/client"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/log"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
const catwalkURL = "https://catwalk.charm.sh"
|
||||
|
||||
// LoadReader config via io.Reader.
|
||||
func LoadReader(fd io.Reader) (*Config, error) {
|
||||
data, err := io.ReadAll(fd)
|
||||
@@ -61,8 +62,8 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
cfg.Options.Debug,
|
||||
)
|
||||
|
||||
// Load known providers, this loads the config from fur
|
||||
providers, err := LoadProviders(client.New())
|
||||
// Load known providers, this loads the config from catwalk
|
||||
providers, err := LoadProviders(catwalk.NewWithURL(catwalkURL))
|
||||
if err != nil || len(providers) == 0 {
|
||||
return nil, fmt.Errorf("failed to load providers: %w", err)
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, p := range cfg.Providers {
|
||||
if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
|
||||
if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic {
|
||||
wg.Add(1)
|
||||
go func(provider ProviderConfig) {
|
||||
defer wg.Done()
|
||||
@@ -117,7 +118,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
|
||||
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
|
||||
knownProviderNames := make(map[string]bool)
|
||||
for _, p := range knownProviders {
|
||||
knownProviderNames[string(p.ID)] = true
|
||||
@@ -136,7 +137,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
p.APIKey = config.APIKey
|
||||
}
|
||||
if len(config.Models) > 0 {
|
||||
models := []provider.Model{}
|
||||
models := []catwalk.Model{}
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, model := range config.Models {
|
||||
@@ -144,8 +145,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
continue
|
||||
}
|
||||
seen[model.ID] = true
|
||||
if model.Model == "" {
|
||||
model.Model = model.ID
|
||||
if model.Name == "" {
|
||||
model.Name = model.ID
|
||||
}
|
||||
models = append(models, model)
|
||||
}
|
||||
@@ -154,8 +155,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
continue
|
||||
}
|
||||
seen[model.ID] = true
|
||||
if model.Model == "" {
|
||||
model.Model = model.ID
|
||||
if model.Name == "" {
|
||||
model.Name = model.ID
|
||||
}
|
||||
models = append(models, model)
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
|
||||
switch p.ID {
|
||||
// Handle specific providers that require additional configuration
|
||||
case provider.InferenceProviderVertexAI:
|
||||
case catwalk.InferenceProviderVertexAI:
|
||||
if !hasVertexCredentials(env) {
|
||||
if configExists {
|
||||
slog.Warn("Skipping Vertex AI provider due to missing credentials")
|
||||
@@ -188,7 +189,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
|
||||
prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
|
||||
case provider.InferenceProviderAzure:
|
||||
case catwalk.InferenceProviderAzure:
|
||||
endpoint, err := resolver.ResolveValue(p.APIEndpoint)
|
||||
if err != nil || endpoint == "" {
|
||||
if configExists {
|
||||
@@ -199,7 +200,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
prepared.BaseURL = endpoint
|
||||
prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
|
||||
case provider.InferenceProviderBedrock:
|
||||
case catwalk.InferenceProviderBedrock:
|
||||
if !hasAWSCredentials(env) {
|
||||
if configExists {
|
||||
slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
|
||||
@@ -239,7 +240,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
// default to OpenAI if not set
|
||||
if providerConfig.Type == "" {
|
||||
providerConfig.Type = provider.TypeOpenAI
|
||||
providerConfig.Type = catwalk.TypeOpenAI
|
||||
}
|
||||
|
||||
if providerConfig.Disable {
|
||||
@@ -260,7 +261,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
delete(c.Providers, id)
|
||||
continue
|
||||
}
|
||||
if providerConfig.Type != provider.TypeOpenAI {
|
||||
if providerConfig.Type != catwalk.TypeOpenAI {
|
||||
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
|
||||
delete(c.Providers, id)
|
||||
continue
|
||||
@@ -315,7 +316,7 @@ func (c *Config) setDefaults(workingDir string) {
|
||||
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
|
||||
}
|
||||
|
||||
func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
|
||||
func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
|
||||
if len(knownProviders) == 0 && len(c.Providers) == 0 {
|
||||
err = fmt.Errorf("no providers configured, please configure at least one provider")
|
||||
return
|
||||
@@ -384,7 +385,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Config) configureSelectedModels(knownProviders []provider.Provider) error {
|
||||
func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
|
||||
defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select default models: %w", err)
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -54,12 +54,12 @@ func TestConfig_setDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProviders(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -80,12 +80,12 @@ func TestConfig_configureProviders(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -96,10 +96,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
"openai": {
|
||||
APIKey: "xyz",
|
||||
BaseURL: "https://api.openai.com/v2",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "test-model",
|
||||
Model: "Updated",
|
||||
ID: "test-model",
|
||||
Name: "Updated",
|
||||
},
|
||||
{
|
||||
ID: "another-model",
|
||||
@@ -122,16 +122,16 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
assert.Equal(t, "xyz", cfg.Providers["openai"].APIKey)
|
||||
assert.Equal(t, "https://api.openai.com/v2", cfg.Providers["openai"].BaseURL)
|
||||
assert.Len(t, cfg.Providers["openai"].Models, 2)
|
||||
assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Model)
|
||||
assert.Equal(t, "Updated", cfg.Providers["openai"].Models[0].Name)
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -142,7 +142,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "xyz",
|
||||
BaseURL: "https://api.someendpoint.com/v2",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "test-model",
|
||||
},
|
||||
@@ -172,12 +172,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -201,12 +201,12 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -223,12 +223,12 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "some-random-model",
|
||||
}},
|
||||
},
|
||||
@@ -246,12 +246,12 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -278,12 +278,12 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -304,12 +304,12 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -329,12 +329,12 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -450,12 +450,12 @@ func TestConfig_IsConfigured(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -489,7 +489,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
Providers: map[string]ProviderConfig{
|
||||
"custom": {
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -502,7 +502,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 1)
|
||||
@@ -515,7 +515,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
Providers: map[string]ProviderConfig{
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -525,7 +525,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 0)
|
||||
@@ -539,7 +539,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{},
|
||||
Models: []catwalk.Model{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -547,7 +547,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 0)
|
||||
@@ -562,7 +562,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: "unsupported",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -572,7 +572,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 0)
|
||||
@@ -586,8 +586,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: provider.TypeOpenAI,
|
||||
Models: []provider.Model{{
|
||||
Type: catwalk.TypeOpenAI,
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -597,7 +597,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 1)
|
||||
@@ -614,9 +614,9 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: provider.TypeOpenAI,
|
||||
Type: catwalk.TypeOpenAI,
|
||||
Disable: true,
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -626,7 +626,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, cfg.Providers, 0)
|
||||
@@ -637,12 +637,12 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -670,12 +670,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -701,12 +701,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -732,12 +732,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "$MISSING_ENDPOINT",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -767,13 +767,13 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
|
||||
func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -803,13 +803,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Equal(t, int64(500), small.MaxTokens)
|
||||
})
|
||||
t.Run("should error if no providers configured", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING_KEY",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -833,13 +833,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("should error if model is missing", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -863,13 +863,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should configure the default models with a custom provider", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING", // will not be included in the config
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -887,7 +887,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "model",
|
||||
DefaultMaxTokens: 600,
|
||||
@@ -912,13 +912,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should fail if no model configured", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING", // will not be included in the config
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -936,7 +936,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{},
|
||||
Models: []catwalk.Model{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -949,13 +949,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("should use the default provider first", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "set",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -973,7 +973,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1000,13 +1000,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
|
||||
func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
t.Run("should override defaults", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "larger-model",
|
||||
DefaultMaxTokens: 2000,
|
||||
@@ -1048,13 +1048,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
assert.Equal(t, int64(500), small.MaxTokens)
|
||||
})
|
||||
t.Run("should be possible to use multiple providers", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1070,7 +1070,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "a-large-model",
|
||||
DefaultSmallModelID: "a-small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "a-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1111,13 +1111,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should override the max tokens only", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
|
||||
@@ -7,17 +7,16 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/client"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
)
|
||||
|
||||
type ProviderClient interface {
|
||||
GetProviders() ([]provider.Provider, error)
|
||||
GetProviders() ([]catwalk.Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
providerOnce sync.Once
|
||||
providerList []provider.Provider
|
||||
providerList []catwalk.Provider
|
||||
)
|
||||
|
||||
// file to cache provider data
|
||||
@@ -41,7 +40,7 @@ func providerCacheFileData() string {
|
||||
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
|
||||
}
|
||||
|
||||
func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
func saveProvidersInCache(path string, providers []catwalk.Provider) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
@@ -55,18 +54,18 @@ func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
|
||||
func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var providers []provider.Provider
|
||||
var providers []catwalk.Provider
|
||||
err = json.Unmarshal(data, &providers)
|
||||
return providers, err
|
||||
}
|
||||
|
||||
func loadProviders(path string, client ProviderClient) ([]provider.Provider, error) {
|
||||
func loadProviders(path string, client ProviderClient) ([]catwalk.Provider, error) {
|
||||
providers, err := client.GetProviders()
|
||||
if err != nil {
|
||||
fallbackToCache, err := loadProvidersFromCache(path)
|
||||
@@ -82,11 +81,11 @@ func loadProviders(path string, client ProviderClient) ([]provider.Provider, err
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func Providers() ([]provider.Provider, error) {
|
||||
return LoadProviders(client.New())
|
||||
func Providers() ([]catwalk.Provider, error) {
|
||||
return LoadProviders(catwalk.NewWithURL(catwalkURL))
|
||||
}
|
||||
|
||||
func LoadProviders(client ProviderClient) ([]provider.Provider, error) {
|
||||
func LoadProviders(client ProviderClient) ([]catwalk.Provider, error) {
|
||||
var err error
|
||||
providerOnce.Do(func() {
|
||||
providerList, err = loadProviders(providerCacheFileData(), client)
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -14,11 +14,11 @@ type mockProviderClient struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
|
||||
func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
|
||||
if m.shouldFail {
|
||||
return nil, errors.New("failed to load providers")
|
||||
}
|
||||
return []provider.Provider{
|
||||
return []catwalk.Provider{
|
||||
{
|
||||
Name: "Mock",
|
||||
},
|
||||
@@ -43,7 +43,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
// store providers to a temporary file
|
||||
oldProviders := []provider.Provider{
|
||||
oldProviders := []catwalk.Provider{
|
||||
{
|
||||
Name: "OldProvider",
|
||||
},
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
// Package client provides a client for interacting with the fur service.
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
const defaultURL = "https://fur.charm.sh"
|
||||
|
||||
// Client represents a client for the fur service.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// New creates a new client instance
|
||||
// Uses FUR_URL environment variable or falls back to localhost:8080.
|
||||
func New() *Client {
|
||||
baseURL := os.Getenv("FUR_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultURL
|
||||
}
|
||||
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithURL creates a new client with a specific URL.
|
||||
func NewWithURL(url string) *Client {
|
||||
return &Client{
|
||||
baseURL: url,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviders retrieves all available providers from the service.
|
||||
func (c *Client) GetProviders() ([]provider.Provider, error) {
|
||||
url := fmt.Sprintf("%s/providers", c.baseURL)
|
||||
|
||||
resp, err := c.httpClient.Get(url) //nolint:noctx
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var providers []provider.Provider
|
||||
if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
// Package provider provides types and constants for AI providers.
|
||||
package provider
|
||||
|
||||
// Type represents the type of AI provider.
|
||||
type Type string
|
||||
|
||||
// All the supported AI provider types.
|
||||
const (
|
||||
TypeOpenAI Type = "openai"
|
||||
TypeAnthropic Type = "anthropic"
|
||||
TypeGemini Type = "gemini"
|
||||
TypeAzure Type = "azure"
|
||||
TypeBedrock Type = "bedrock"
|
||||
TypeVertexAI Type = "vertexai"
|
||||
TypeXAI Type = "xai"
|
||||
)
|
||||
|
||||
// InferenceProvider represents the inference provider identifier.
|
||||
type InferenceProvider string
|
||||
|
||||
// All the inference providers supported by the system.
|
||||
const (
|
||||
InferenceProviderOpenAI InferenceProvider = "openai"
|
||||
InferenceProviderAnthropic InferenceProvider = "anthropic"
|
||||
InferenceProviderGemini InferenceProvider = "gemini"
|
||||
InferenceProviderAzure InferenceProvider = "azure"
|
||||
InferenceProviderBedrock InferenceProvider = "bedrock"
|
||||
InferenceProviderVertexAI InferenceProvider = "vertexai"
|
||||
InferenceProviderXAI InferenceProvider = "xai"
|
||||
InferenceProviderGROQ InferenceProvider = "groq"
|
||||
InferenceProviderOpenRouter InferenceProvider = "openrouter"
|
||||
)
|
||||
|
||||
// Provider represents an AI provider configuration.
|
||||
type Provider struct {
|
||||
Name string `json:"name"`
|
||||
ID InferenceProvider `json:"id"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
APIEndpoint string `json:"api_endpoint,omitempty"`
|
||||
Type Type `json:"type,omitempty"`
|
||||
DefaultLargeModelID string `json:"default_large_model_id,omitempty"`
|
||||
DefaultSmallModelID string `json:"default_small_model_id,omitempty"`
|
||||
Models []Model `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model represents an AI model configuration.
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
ContextWindow int64 `json:"context_window"`
|
||||
DefaultMaxTokens int64 `json:"default_max_tokens"`
|
||||
CanReason bool `json:"can_reason"`
|
||||
HasReasoningEffort bool `json:"has_reasoning_efforts"`
|
||||
DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"`
|
||||
SupportsImages bool `json:"supports_attachments"`
|
||||
}
|
||||
|
||||
// KnownProviders returns all the known inference providers.
|
||||
func KnownProviders() []InferenceProvider {
|
||||
return []InferenceProvider{
|
||||
InferenceProviderOpenAI,
|
||||
InferenceProviderAnthropic,
|
||||
InferenceProviderGemini,
|
||||
InferenceProviderAzure,
|
||||
InferenceProviderBedrock,
|
||||
InferenceProviderVertexAI,
|
||||
InferenceProviderXAI,
|
||||
InferenceProviderGROQ,
|
||||
InferenceProviderOpenRouter,
|
||||
}
|
||||
}
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
fur "github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/llm/provider"
|
||||
@@ -52,7 +52,7 @@ type AgentEvent struct {
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[AgentEvent]
|
||||
Model() fur.Model
|
||||
Model() catwalk.Model
|
||||
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
|
||||
Cancel(sessionID string)
|
||||
CancelAll()
|
||||
@@ -219,7 +219,7 @@ func NewAgent(
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (a *agent) Model() fur.Model {
|
||||
func (a *agent) Model() catwalk.Model {
|
||||
return *config.Get().GetModelByType(a.agentCfg.Model)
|
||||
}
|
||||
|
||||
@@ -638,7 +638,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
|
||||
sess, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
)
|
||||
|
||||
func CoderPrompt(p string, contextFiles ...string) string {
|
||||
var basePrompt string
|
||||
switch p {
|
||||
case string(provider.InferenceProviderOpenAI):
|
||||
case string(catwalk.InferenceProviderOpenAI):
|
||||
basePrompt = baseOpenAICoderPrompt
|
||||
case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
|
||||
case string(catwalk.InferenceProviderGemini), string(catwalk.InferenceProviderVertexAI):
|
||||
basePrompt = baseGeminiCoderPrompt
|
||||
default:
|
||||
basePrompt = baseAnthropicCoderPrompt
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/bedrock"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -71,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
contentBlocks = append(contentBlocks, content)
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
|
||||
base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
|
||||
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
|
||||
contentBlocks = append(contentBlocks, imageBlock)
|
||||
}
|
||||
@@ -529,6 +529,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *anthropicClient) Model() provider.Model {
|
||||
func (a *anthropicClient) Model() catwalk.Model {
|
||||
return a.providerOptions.model(a.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -32,7 +32,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
}
|
||||
}
|
||||
|
||||
opts.model = func(modelType config.SelectedModelType) provider.Model {
|
||||
opts.model = func(modelType config.SelectedModelType) catwalk.Model {
|
||||
model := config.Get().GetModelByType(modelType)
|
||||
|
||||
// Prefix the model name with region
|
||||
@@ -88,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
|
||||
return b.childProvider.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (b *bedrockClient) Model() provider.Model {
|
||||
func (b *bedrockClient) Model() catwalk.Model {
|
||||
return b.providerOptions.model(b.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/google/uuid"
|
||||
@@ -463,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiClient) Model() provider.Model {
|
||||
func (g *geminiClient) Model() catwalk.Model {
|
||||
return g.providerOptions.model(g.providerOptions.modelType)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/openai/openai-go"
|
||||
@@ -66,7 +66,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
|
||||
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
|
||||
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
|
||||
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
|
||||
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
|
||||
@@ -486,6 +486,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openaiClient) Model() provider.Model {
|
||||
func (o *openaiClient) Model() catwalk.Model {
|
||||
return o.providerOptions.model(o.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/openai/openai-go"
|
||||
@@ -55,10 +55,10 @@ func TestOpenAIClientStreamChoices(t *testing.T) {
|
||||
modelType: config.SelectedModelTypeLarge,
|
||||
apiKey: "test-key",
|
||||
systemMessage: "test",
|
||||
model: func(config.SelectedModelType) provider.Model {
|
||||
return provider.Model{
|
||||
ID: "test-model",
|
||||
Model: "test-model",
|
||||
model: func(config.SelectedModelType) catwalk.Model {
|
||||
return catwalk.Model{
|
||||
ID: "test-model",
|
||||
Name: "test-model",
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ type Provider interface {
|
||||
|
||||
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() provider.Model
|
||||
Model() catwalk.Model
|
||||
}
|
||||
|
||||
type providerClientOptions struct {
|
||||
@@ -65,7 +65,7 @@ type providerClientOptions struct {
|
||||
config config.ProviderConfig
|
||||
apiKey string
|
||||
modelType config.SelectedModelType
|
||||
model func(config.SelectedModelType) provider.Model
|
||||
model func(config.SelectedModelType) catwalk.Model
|
||||
disableCache bool
|
||||
systemMessage string
|
||||
maxTokens int64
|
||||
@@ -80,7 +80,7 @@ type ProviderClient interface {
|
||||
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
|
||||
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() provider.Model
|
||||
Model() catwalk.Model
|
||||
}
|
||||
|
||||
type baseProvider[C ProviderClient] struct {
|
||||
@@ -109,7 +109,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
|
||||
return p.client.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() provider.Model {
|
||||
func (p *baseProvider[C]) Model() catwalk.Model {
|
||||
return p.client.Model()
|
||||
}
|
||||
|
||||
@@ -149,7 +149,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
apiKey: resolvedAPIKey,
|
||||
extraHeaders: cfg.ExtraHeaders,
|
||||
extraBody: cfg.ExtraBody,
|
||||
model: func(tp config.SelectedModelType) provider.Model {
|
||||
model: func(tp config.SelectedModelType) catwalk.Model {
|
||||
return *config.Get().GetModelByType(tp)
|
||||
},
|
||||
}
|
||||
@@ -157,37 +157,37 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch cfg.Type {
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions, false),
|
||||
}, nil
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeGemini:
|
||||
case catwalk.TypeGemini:
|
||||
return &baseProvider[GeminiClient]{
|
||||
options: clientOptions,
|
||||
client: newGeminiClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeBedrock:
|
||||
case catwalk.TypeBedrock:
|
||||
return &baseProvider[BedrockClient]{
|
||||
options: clientOptions,
|
||||
client: newBedrockClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeAzure:
|
||||
case catwalk.TypeAzure:
|
||||
return &baseProvider[AzureClient]{
|
||||
options: clientOptions,
|
||||
client: newAzureClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeVertexAI:
|
||||
case catwalk.TypeVertexAI:
|
||||
return &baseProvider[VertexAIClient]{
|
||||
options: clientOptions,
|
||||
client: newVertexAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeXAI:
|
||||
case catwalk.TypeXAI:
|
||||
clientOptions.baseURL = "https://api.x.ai/v1"
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
@@ -74,9 +74,9 @@ type BinaryContent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (bc BinaryContent) String(p provider.InferenceProvider) string {
|
||||
func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
|
||||
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
|
||||
if p == provider.InferenceProviderOpenAI {
|
||||
if p == catwalk.InferenceProviderOpenAI {
|
||||
return "data:" + bc.MIMEType + ";base64," + base64Encoded
|
||||
}
|
||||
return base64Encoded
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
|
||||
"github.com/charmbracelet/bubbles/v2/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/anim"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core"
|
||||
@@ -369,11 +369,11 @@ func (m *assistantSectionModel) View() string {
|
||||
model := config.Get().GetModel(m.message.Provider, m.message.Model)
|
||||
if model == nil {
|
||||
// This means the model is not configured anymore
|
||||
model = &provider.Model{
|
||||
Model: "Unknown Model",
|
||||
model = &catwalk.Model{
|
||||
Name: "Unknown Model",
|
||||
}
|
||||
}
|
||||
modelFormatted := t.S().Muted.Render(model.Model)
|
||||
modelFormatted := t.S().Muted.Render(model.Name)
|
||||
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
|
||||
return t.S().Base.PaddingLeft(2).Render(
|
||||
core.Section(assistant, m.width-2),
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"sync"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/diff"
|
||||
"github.com/charmbracelet/crush/internal/fsext"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/lsp"
|
||||
"github.com/charmbracelet/crush/internal/lsp/protocol"
|
||||
@@ -897,7 +897,7 @@ func (s *sidebarCmp) currentModelBlock() string {
|
||||
t := styles.CurrentTheme()
|
||||
|
||||
modelIcon := t.S().Base.Foreground(t.FgSubtle).Render(styles.ModelIcon)
|
||||
modelName := t.S().Text.Render(model.Model)
|
||||
modelName := t.S().Text.Render(model.Name)
|
||||
modelInfo := fmt.Sprintf("%s %s", modelIcon, modelName)
|
||||
parts := []string{
|
||||
modelInfo,
|
||||
@@ -905,14 +905,14 @@ func (s *sidebarCmp) currentModelBlock() string {
|
||||
if model.CanReason {
|
||||
reasoningInfoStyle := t.S().Subtle.PaddingLeft(2)
|
||||
switch modelProvider.Type {
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
reasoningEffort := model.DefaultReasoningEffort
|
||||
if selectedModel.ReasoningEffort != "" {
|
||||
reasoningEffort = selectedModel.ReasoningEffort
|
||||
}
|
||||
formatter := cases.Title(language.English, cases.NoLower)
|
||||
parts = append(parts, reasoningInfoStyle.Render(formatter.String(fmt.Sprintf("Reasoning %s", reasoningEffort))))
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
formatter := cases.Title(language.English, cases.NoLower)
|
||||
if selectedModel.Think {
|
||||
parts = append(parts, reasoningInfoStyle.Render(formatter.String("Thinking on")))
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
"github.com/charmbracelet/bubbles/v2/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/chat"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
@@ -109,7 +109,7 @@ func (s *splashCmp) SetOnboarding(onboarding bool) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
filteredProviders := []provider.Provider{}
|
||||
filteredProviders := []catwalk.Provider{}
|
||||
simpleProviders := []string{
|
||||
"anthropic",
|
||||
"openai",
|
||||
@@ -407,7 +407,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
|
||||
func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
|
||||
providers, err := config.Providers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/help"
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/chat"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
@@ -270,7 +270,7 @@ func (c *commandDialogCmp) defaultCommands() []Command {
|
||||
providerCfg := cfg.GetProviderForModel(agentCfg.Model)
|
||||
model := cfg.GetModelByType(agentCfg.Model)
|
||||
if providerCfg != nil && model != nil &&
|
||||
providerCfg.Type == provider.TypeAnthropic && model.CanReason {
|
||||
providerCfg.Type == catwalk.TypeAnthropic && model.CanReason {
|
||||
selectedModel := cfg.Models[agentCfg.Model]
|
||||
status := "Enable"
|
||||
if selectedModel.Think {
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"slices"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core/list"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
type ModelListComponent struct {
|
||||
list list.ListModel
|
||||
modelType int
|
||||
providers []provider.Provider
|
||||
providers []catwalk.Provider
|
||||
}
|
||||
|
||||
func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
|
||||
@@ -109,19 +109,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
}
|
||||
|
||||
// Check if this provider is not in the known providers list
|
||||
if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) {
|
||||
if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
|
||||
// Convert config provider to provider.Provider format
|
||||
configProvider := provider.Provider{
|
||||
configProvider := catwalk.Provider{
|
||||
Name: providerConfig.Name,
|
||||
ID: provider.InferenceProvider(providerID),
|
||||
Models: make([]provider.Model, len(providerConfig.Models)),
|
||||
ID: catwalk.InferenceProvider(providerID),
|
||||
Models: make([]catwalk.Model, len(providerConfig.Models)),
|
||||
}
|
||||
|
||||
// Convert models
|
||||
for i, model := range providerConfig.Models {
|
||||
configProvider.Models[i] = provider.Model{
|
||||
configProvider.Models[i] = catwalk.Model{
|
||||
ID: model.ID,
|
||||
Model: model.Model,
|
||||
Name: model.Name,
|
||||
CostPer1MIn: model.CostPer1MIn,
|
||||
CostPer1MOut: model.CostPer1MOut,
|
||||
CostPer1MInCached: model.CostPer1MInCached,
|
||||
@@ -144,7 +144,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
section.SetInfo(configured)
|
||||
modelItems = append(modelItems, section)
|
||||
for _, model := range configProvider.Models {
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
|
||||
Provider: configProvider,
|
||||
Model: model,
|
||||
}))
|
||||
@@ -179,7 +179,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
}
|
||||
modelItems = append(modelItems, section)
|
||||
for _, model := range provider.Models {
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}))
|
||||
@@ -201,6 +201,6 @@ func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
|
||||
m.list.SetFilterPlaceholder(placeholder)
|
||||
}
|
||||
|
||||
func (m *ModelListComponent) SetProviders(providers []provider.Provider) {
|
||||
func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
|
||||
m.providers = providers
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
"github.com/charmbracelet/bubbles/v2/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core/list"
|
||||
@@ -48,8 +48,8 @@ type ModelDialog interface {
|
||||
}
|
||||
|
||||
type ModelOption struct {
|
||||
Provider provider.Provider
|
||||
Model provider.Model
|
||||
Provider catwalk.Provider
|
||||
Model catwalk.Model
|
||||
}
|
||||
|
||||
type modelDialogCmp struct {
|
||||
@@ -363,7 +363,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
|
||||
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
|
||||
providers, err := config.Providers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -279,7 +279,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if model.SupportsImages {
|
||||
return p, util.CmdHandler(OpenFilePickerMsg{})
|
||||
} else {
|
||||
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Model)
|
||||
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
|
||||
}
|
||||
case key.Matches(msg, p.keyMap.Tab):
|
||||
if p.session.ID == "" {
|
||||
|
||||
Reference in New Issue
Block a user