Merge remote-tracking branch 'origin/main' into tool_improvements

This commit is contained in:
Kujtim Hoxha
2025-07-28 21:51:45 +02:00
18 changed files with 489 additions and 412 deletions

2
.github/CODEOWNERS vendored
View File

@@ -1 +1 @@
* @kujtimiihoxha
*.go @kujtimiihoxha

View File

@@ -19,7 +19,8 @@
- **Interfaces**: Define interfaces in consuming packages, keep them small and focused
- **Structs**: Use struct embedding for composition, group related fields
- **Constants**: Use typed constants with iota for enums, group in const blocks
- **Testing**: Use testify/assert and testify/require, parallel tests with `t.Parallel()`
- **Testing**: Use testify's `require` package, parallel tests with `t.Parallel()`,
`t.SetEnv()` to set environment variables.
- **JSON tags**: Use snake_case for JSON field names
- **File permissions**: Use octal notation (0o755, 0o644) for file permissions
- **Comments**: End comments in periods unless comments are at the end of the line.

115
LICENSE
View File

@@ -1,6 +1,119 @@
# Functional Source License, Version 1.1, MIT Future License
## Abbreviation
FSL-1.1-MIT
## Notice
Copyright 2025 Charmbracelet, Inc
## Terms and Conditions
### Licensor ("We")
The party offering the Software under these Terms and Conditions.
### The Software
The "Software" is each version of the software that we make available under
these Terms and Conditions, as indicated by our inclusion of these Terms and
Conditions with the Software.
### License Grant
Subject to your compliance with this License Grant and the Patents,
Redistribution and Trademark clauses below, we hereby grant you the right to
use, copy, modify, create derivative works, publicly perform, publicly display
and redistribute the Software for any Permitted Purpose identified below.
### Permitted Purpose
A Permitted Purpose is any purpose other than a Competing Use. A Competing Use
means making the Software available to others in a commercial product or
service that:
1. substitutes for the Software;
2. substitutes for any other product or service we offer using the Software
that exists as of the date we make the Software available; or
3. offers the same or substantially similar functionality as the Software.
Permitted Purposes specifically include using the Software:
1. for your internal use and access;
2. for non-commercial education;
3. for non-commercial research; and
4. in connection with professional services that you provide to a licensee
using the Software in accordance with these Terms and Conditions.
### Patents
To the extent your use for a Permitted Purpose would necessarily infringe our
patents, the license grant above includes a license under our patents. If you
make a claim against any party that the Software infringes or contributes to
the infringement of any patent, then your patent license to the Software ends
immediately.
### Redistribution
The Terms and Conditions apply to all copies, modifications and derivatives of
the Software.
If you redistribute any copies, modifications or derivatives of the Software,
you must include a copy of or a link to these Terms and Conditions and not
remove any copyright notices provided in or with the Software.
### Disclaimer
THE SOFTWARE IS PROVIDED "AS IS" AND WITHOUT WARRANTIES OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING WITHOUT LIMITATION WARRANTIES OF FITNESS FOR A PARTICULAR
PURPOSE, MERCHANTABILITY, TITLE OR NON-INFRINGEMENT.
IN NO EVENT WILL WE HAVE ANY LIABILITY TO YOU ARISING OUT OF OR RELATED TO THE
SOFTWARE, INCLUDING INDIRECT, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES,
EVEN IF WE HAVE BEEN INFORMED OF THEIR POSSIBILITY IN ADVANCE.
### Trademarks
Except for displaying the License Details and identifying us as the origin of
the Software, you have no right under these Terms and Conditions to use our
trademarks, trade names, service marks or product names.
## Grant of Future License
We hereby irrevocably grant you an additional license to use the Software under
the MIT license that is effective on the second anniversary of the date we make
the Software available. On or after that date, you may use the Software under
the MIT license, in which case the following will apply:
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
MIT License
Copyright (c) 2025 Kujtim Hoxha
Copyright (c) 2025-03-21 - 2025-05-30 Kujtim Hoxha
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

2
go.mod
View File

@@ -31,7 +31,7 @@ require (
github.com/ncruces/go-sqlite3 v0.25.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/nxadm/tail v1.4.11
github.com/openai/openai-go v1.8.2
github.com/openai/openai-go v1.11.1
github.com/pressly/goose/v3 v3.24.2
github.com/qjebbs/go-jsons v0.0.0-20221222033332-a534c5fc1c4c
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06

4
go.sum
View File

@@ -201,8 +201,8 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY=
github.com/nxadm/tail v1.4.11/go.mod h1:OTaG3NK980DZzxbRq6lEuzgU+mug70nY11sMd4JXXHc=
github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8=
github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/openai/openai-go v1.11.1 h1:fTQ4Sr9eoRiWFAoHzXiZZpVi6KtLeoTMyGrcOCudjNU=
github.com/openai/openai-go v1.11.1/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -11,7 +11,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
@@ -28,12 +28,12 @@ func TestConfig_LoadFromReaders(t *testing.T) {
loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
assert.NoError(t, err)
assert.NotNil(t, loadedConfig)
assert.Equal(t, 1, loadedConfig.Providers.Len())
require.NoError(t, err)
require.NotNil(t, loadedConfig)
require.Equal(t, 1, loadedConfig.Providers.Len())
pc, _ := loadedConfig.Providers.Get("openai")
assert.Equal(t, "key2", pc.APIKey)
assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
require.Equal(t, "key2", pc.APIKey)
require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}
func TestConfig_setDefaults(t *testing.T) {
@@ -41,18 +41,18 @@ func TestConfig_setDefaults(t *testing.T) {
cfg.setDefaults("/tmp")
assert.NotNil(t, cfg.Options)
assert.NotNil(t, cfg.Options.TUI)
assert.NotNil(t, cfg.Options.ContextPaths)
assert.NotNil(t, cfg.Providers)
assert.NotNil(t, cfg.Models)
assert.NotNil(t, cfg.LSP)
assert.NotNil(t, cfg.MCP)
assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
require.NotNil(t, cfg.Options)
require.NotNil(t, cfg.Options.TUI)
require.NotNil(t, cfg.Options.ContextPaths)
require.NotNil(t, cfg.Providers)
require.NotNil(t, cfg.Models)
require.NotNil(t, cfg.LSP)
require.NotNil(t, cfg.MCP)
require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
for _, path := range defaultContextPaths {
assert.Contains(t, cfg.Options.ContextPaths, path)
require.Contains(t, cfg.Options.ContextPaths, path)
}
assert.Equal(t, "/tmp", cfg.workingDir)
require.Equal(t, "/tmp", cfg.workingDir)
}
func TestConfig_configureProviders(t *testing.T) {
@@ -74,12 +74,12 @@ func TestConfig_configureProviders(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len())
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}
func TestConfig_configureProvidersWithOverride(t *testing.T) {
@@ -117,15 +117,15 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len())
require.NoError(t, err)
require.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "xyz", pc.APIKey)
assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
assert.Len(t, pc.Models, 2)
assert.Equal(t, "Updated", pc.Models[0].Name)
require.Equal(t, "xyz", pc.APIKey)
require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
require.Len(t, pc.Models, 2)
require.Equal(t, "Updated", pc.Models[0].Name)
}
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
@@ -159,20 +159,20 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
// Should be to because of the env variable
assert.Equal(t, cfg.Providers.Len(), 2)
require.Equal(t, cfg.Providers.Len(), 2)
// We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("custom")
assert.Equal(t, "xyz", pc.APIKey)
require.Equal(t, "xyz", pc.APIKey)
// Make sure we set the ID correctly
assert.Equal(t, "custom", pc.ID)
assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
assert.Len(t, pc.Models, 1)
require.Equal(t, "custom", pc.ID)
require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
require.Len(t, pc.Models, 1)
_, ok := cfg.Providers.Get("openai")
assert.True(t, ok, "OpenAI provider should still be present")
require.True(t, ok, "OpenAI provider should still be present")
}
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
@@ -195,13 +195,13 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
bedrockProvider, ok := cfg.Providers.Get("bedrock")
assert.True(t, ok, "Bedrock provider should be present")
assert.Len(t, bedrockProvider.Models, 1)
assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
require.True(t, ok, "Bedrock provider should be present")
require.Len(t, bedrockProvider.Models, 1)
require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
}
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
@@ -221,9 +221,9 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
// Provider should not be configured without credentials
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
@@ -246,7 +246,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.Error(t, err)
require.Error(t, err)
}
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
@@ -270,15 +270,15 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
vertexProvider, ok := cfg.Providers.Get("vertexai")
assert.True(t, ok, "VertexAI provider should be present")
assert.Len(t, vertexProvider.Models, 1)
assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
require.True(t, ok, "VertexAI provider should be present")
require.Len(t, vertexProvider.Models, 1)
require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
@@ -302,9 +302,9 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
// Provider should not be configured without proper credentials
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
@@ -327,9 +327,9 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
// Provider should not be configured without project
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
}
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
@@ -351,12 +351,12 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.NoError(t, err)
require.Equal(t, cfg.Providers.Len(), 1)
// Provider ID should be set
pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "openai", pc.ID)
require.Equal(t, "openai", pc.ID)
}
func TestConfig_EnabledProviders(t *testing.T) {
@@ -377,7 +377,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
}
enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 2)
require.Len(t, enabled, 2)
})
t.Run("some providers disabled", func(t *testing.T) {
@@ -397,8 +397,8 @@ func TestConfig_EnabledProviders(t *testing.T) {
}
enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 1)
assert.Equal(t, "openai", enabled[0].ID)
require.Len(t, enabled, 1)
require.Equal(t, "openai", enabled[0].ID)
})
t.Run("empty providers map", func(t *testing.T) {
@@ -407,7 +407,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
}
enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 0)
require.Len(t, enabled, 0)
})
}
@@ -423,7 +423,7 @@ func TestConfig_IsConfigured(t *testing.T) {
}),
}
assert.True(t, cfg.IsConfigured())
require.True(t, cfg.IsConfigured())
})
t.Run("returns false when no providers are configured", func(t *testing.T) {
@@ -431,7 +431,7 @@ func TestConfig_IsConfigured(t *testing.T) {
Providers: csync.NewMap[string, ProviderConfig](),
}
assert.False(t, cfg.IsConfigured())
require.False(t, cfg.IsConfigured())
})
t.Run("returns false when all providers are disabled", func(t *testing.T) {
@@ -450,7 +450,7 @@ func TestConfig_IsConfigured(t *testing.T) {
}),
}
assert.False(t, cfg.IsConfigured())
require.False(t, cfg.IsConfigured())
})
}
@@ -480,12 +480,12 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
// Provider should be removed from config when disabled
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
require.False(t, exists)
}
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
@@ -508,11 +508,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.Equal(t, cfg.Providers.Len(), 1)
_, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
require.True(t, exists)
})
t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
@@ -531,11 +531,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("custom provider with no models is removed", func(t *testing.T) {
@@ -553,11 +553,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
@@ -578,11 +578,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
@@ -603,14 +603,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.Equal(t, cfg.Providers.Len(), 1)
customProvider, exists := cfg.Providers.Get("custom")
assert.True(t, exists)
assert.Equal(t, "custom", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey)
assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
require.True(t, exists)
require.Equal(t, "custom", customProvider.ID)
require.Equal(t, "test-key", customProvider.APIKey)
require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
})
t.Run("custom anthropic provider is supported", func(t *testing.T) {
@@ -631,15 +631,15 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.Equal(t, cfg.Providers.Len(), 1)
customProvider, exists := cfg.Providers.Get("custom-anthropic")
assert.True(t, exists)
assert.Equal(t, "custom-anthropic", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey)
assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
require.True(t, exists)
require.Equal(t, "custom-anthropic", customProvider.ID)
require.Equal(t, "test-key", customProvider.APIKey)
require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
})
t.Run("disabled custom provider is removed", func(t *testing.T) {
@@ -661,11 +661,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom")
assert.False(t, exists)
require.False(t, exists)
})
}
@@ -696,11 +696,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("vertexai")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
@@ -727,11 +727,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("bedrock")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
@@ -758,11 +758,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0)
require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("openai")
assert.False(t, exists)
require.False(t, exists)
})
t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
@@ -791,11 +791,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1)
require.Equal(t, cfg.Providers.Len(), 1)
_, exists := cfg.Providers.Get("openai")
assert.True(t, exists)
require.True(t, exists)
})
}
@@ -825,16 +825,16 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err)
assert.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens)
require.NoError(t, err)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
require.Equal(t, int64(1000), large.MaxTokens)
require.Equal(t, "small-model", small.Model)
require.Equal(t, "openai", small.Provider)
require.Equal(t, int64(500), small.MaxTokens)
})
t.Run("should error if no providers configured", func(t *testing.T) {
knownProviders := []catwalk.Provider{
@@ -861,10 +861,10 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err)
require.Error(t, err)
})
t.Run("should error if model is missing", func(t *testing.T) {
knownProviders := []catwalk.Provider{
@@ -891,9 +891,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err)
require.Error(t, err)
})
t.Run("should configure the default models with a custom provider", func(t *testing.T) {
@@ -934,15 +934,15 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err)
assert.Equal(t, "model", large.Model)
assert.Equal(t, "custom", large.Provider)
assert.Equal(t, int64(600), large.MaxTokens)
assert.Equal(t, "model", small.Model)
assert.Equal(t, "custom", small.Provider)
assert.Equal(t, int64(600), small.MaxTokens)
require.NoError(t, err)
require.Equal(t, "model", large.Model)
require.Equal(t, "custom", large.Provider)
require.Equal(t, int64(600), large.MaxTokens)
require.Equal(t, "model", small.Model)
require.Equal(t, "custom", small.Provider)
require.Equal(t, int64(600), small.MaxTokens)
})
t.Run("should fail if no model configured", func(t *testing.T) {
@@ -978,9 +978,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err)
require.Error(t, err)
})
t.Run("should use the default provider first", func(t *testing.T) {
knownProviders := []catwalk.Provider{
@@ -1020,15 +1020,15 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err)
assert.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens)
require.NoError(t, err)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
require.Equal(t, int64(1000), large.MaxTokens)
require.Equal(t, "small-model", small.Model)
require.Equal(t, "openai", small.Provider)
require.Equal(t, int64(500), small.MaxTokens)
})
}
@@ -1068,18 +1068,18 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
assert.Equal(t, "larger-model", large.Model)
assert.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(2000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens)
require.Equal(t, "larger-model", large.Model)
require.Equal(t, "openai", large.Provider)
require.Equal(t, int64(2000), large.MaxTokens)
require.Equal(t, "small-model", small.Model)
require.Equal(t, "openai", small.Provider)
require.Equal(t, int64(500), small.MaxTokens)
})
t.Run("should be possible to use multiple providers", func(t *testing.T) {
knownProviders := []catwalk.Provider{
@@ -1130,18 +1130,18 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall]
assert.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "a-small-model", small.Model)
assert.Equal(t, "anthropic", small.Provider)
assert.Equal(t, int64(300), small.MaxTokens)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
require.Equal(t, int64(1000), large.MaxTokens)
require.Equal(t, "a-small-model", small.Model)
require.Equal(t, "anthropic", small.Provider)
require.Equal(t, int64(300), small.MaxTokens)
})
t.Run("should override the max tokens only", func(t *testing.T) {
@@ -1175,13 +1175,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err)
require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge]
assert.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(100), large.MaxTokens)
require.Equal(t, "large-model", large.Model)
require.Equal(t, "openai", large.Provider)
require.Equal(t, int64(100), large.MaxTokens)
})
}

View File

@@ -7,7 +7,7 @@ import (
"testing"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProviderClient struct {
@@ -29,14 +29,14 @@ func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
// check if file got saved
fileInfo, err := os.Stat(tmpPath)
assert.NoError(t, err)
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
require.NoError(t, err)
require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
}
func TestProvider_loadProvidersWithIssues(t *testing.T) {
@@ -58,16 +58,16 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
t.Fatalf("Failed to write old providers to file: %v", err)
}
providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err)
assert.NotNil(t, providers)
assert.Len(t, providers, 1)
assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
require.NoError(t, err)
require.NotNil(t, providers)
require.Len(t, providers, 1)
require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
}
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(client, tmpPath)
assert.Error(t, err)
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
require.Error(t, err)
require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
}

View File

@@ -6,7 +6,7 @@ import (
"testing"
"github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockShell implements the Shell interface for testing
@@ -85,10 +85,10 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value)
if tt.expectError {
assert.Error(t, err)
require.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
@@ -250,10 +250,10 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value)
if tt.expectError {
assert.Error(t, err)
require.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
@@ -306,10 +306,10 @@ func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value)
if tt.expectError {
assert.Error(t, err)
require.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
@@ -319,14 +319,14 @@ func TestNewShellVariableResolver(t *testing.T) {
testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
resolver := NewShellVariableResolver(testEnv)
assert.NotNil(t, resolver)
assert.Implements(t, (*VariableResolver)(nil), resolver)
require.NotNil(t, resolver)
require.Implements(t, (*VariableResolver)(nil), resolver)
}
func TestNewEnvironmentVariableResolver(t *testing.T) {
testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
resolver := NewEnvironmentVariableResolver(testEnv)
assert.NotNil(t, resolver)
assert.Implements(t, (*VariableResolver)(nil), resolver)
require.NotNil(t, resolver)
require.Implements(t, (*VariableResolver)(nil), resolver)
}

3
internal/csync/doc.go Normal file
View File

@@ -0,0 +1,3 @@
// Package csync provides concurrent data structures for safe access in
// multi-threaded environments.
package csync

View File

@@ -6,16 +6,16 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewMap(t *testing.T) {
t.Parallel()
m := NewMap[string, int]()
assert.NotNil(t, m)
assert.NotNil(t, m.inner)
assert.Equal(t, 0, m.Len())
require.NotNil(t, m)
require.NotNil(t, m.inner)
require.Equal(t, 0, m.Len())
}
func TestNewMapFrom(t *testing.T) {
@@ -27,13 +27,13 @@ func TestNewMapFrom(t *testing.T) {
}
m := NewMapFrom(original)
assert.NotNil(t, m)
assert.Equal(t, original, m.inner)
assert.Equal(t, 2, m.Len())
require.NotNil(t, m)
require.Equal(t, original, m.inner)
require.Equal(t, 2, m.Len())
value, ok := m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 1, value)
require.True(t, ok)
require.Equal(t, 1, value)
}
func TestMap_Set(t *testing.T) {
@@ -43,15 +43,15 @@ func TestMap_Set(t *testing.T) {
m.Set("key1", 42)
value, ok := m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 42, value)
assert.Equal(t, 1, m.Len())
require.True(t, ok)
require.Equal(t, 42, value)
require.Equal(t, 1, m.Len())
m.Set("key1", 100)
value, ok = m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 100, value)
assert.Equal(t, 1, m.Len())
require.True(t, ok)
require.Equal(t, 100, value)
require.Equal(t, 1, m.Len())
}
func TestMap_Get(t *testing.T) {
@@ -60,13 +60,13 @@ func TestMap_Get(t *testing.T) {
m := NewMap[string, int]()
value, ok := m.Get("nonexistent")
assert.False(t, ok)
assert.Equal(t, 0, value)
require.False(t, ok)
require.Equal(t, 0, value)
m.Set("key1", 42)
value, ok = m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 42, value)
require.True(t, ok)
require.Equal(t, 42, value)
}
func TestMap_Del(t *testing.T) {
@@ -76,38 +76,38 @@ func TestMap_Del(t *testing.T) {
m.Set("key1", 42)
m.Set("key2", 100)
assert.Equal(t, 2, m.Len())
require.Equal(t, 2, m.Len())
m.Del("key1")
_, ok := m.Get("key1")
assert.False(t, ok)
assert.Equal(t, 1, m.Len())
require.False(t, ok)
require.Equal(t, 1, m.Len())
value, ok := m.Get("key2")
assert.True(t, ok)
assert.Equal(t, 100, value)
require.True(t, ok)
require.Equal(t, 100, value)
m.Del("nonexistent")
assert.Equal(t, 1, m.Len())
require.Equal(t, 1, m.Len())
}
func TestMap_Len(t *testing.T) {
t.Parallel()
m := NewMap[string, int]()
assert.Equal(t, 0, m.Len())
require.Equal(t, 0, m.Len())
m.Set("key1", 1)
assert.Equal(t, 1, m.Len())
require.Equal(t, 1, m.Len())
m.Set("key2", 2)
assert.Equal(t, 2, m.Len())
require.Equal(t, 2, m.Len())
m.Del("key1")
assert.Equal(t, 1, m.Len())
require.Equal(t, 1, m.Len())
m.Del("key2")
assert.Equal(t, 0, m.Len())
require.Equal(t, 0, m.Len())
}
func TestMap_Take(t *testing.T) {
@@ -117,19 +117,19 @@ func TestMap_Take(t *testing.T) {
m.Set("key1", 42)
m.Set("key2", 100)
assert.Equal(t, 2, m.Len())
require.Equal(t, 2, m.Len())
value, ok := m.Take("key1")
assert.True(t, ok)
assert.Equal(t, 42, value)
assert.Equal(t, 1, m.Len())
require.True(t, ok)
require.Equal(t, 42, value)
require.Equal(t, 1, m.Len())
_, exists := m.Get("key1")
assert.False(t, exists)
require.False(t, exists)
value, ok = m.Get("key2")
assert.True(t, ok)
assert.Equal(t, 100, value)
require.True(t, ok)
require.Equal(t, 100, value)
}
func TestMap_Take_NonexistentKey(t *testing.T) {
@@ -139,13 +139,13 @@ func TestMap_Take_NonexistentKey(t *testing.T) {
m.Set("key1", 42)
value, ok := m.Take("nonexistent")
assert.False(t, ok)
assert.Equal(t, 0, value)
assert.Equal(t, 1, m.Len())
require.False(t, ok)
require.Equal(t, 0, value)
require.Equal(t, 1, m.Len())
value, ok = m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 42, value)
require.True(t, ok)
require.Equal(t, 42, value)
}
func TestMap_Take_EmptyMap(t *testing.T) {
@@ -154,9 +154,9 @@ func TestMap_Take_EmptyMap(t *testing.T) {
m := NewMap[string, int]()
value, ok := m.Take("key1")
assert.False(t, ok)
assert.Equal(t, 0, value)
assert.Equal(t, 0, m.Len())
require.False(t, ok)
require.Equal(t, 0, value)
require.Equal(t, 0, m.Len())
}
func TestMap_Take_SameKeyTwice(t *testing.T) {
@@ -166,14 +166,14 @@ func TestMap_Take_SameKeyTwice(t *testing.T) {
m.Set("key1", 42)
value, ok := m.Take("key1")
assert.True(t, ok)
assert.Equal(t, 42, value)
assert.Equal(t, 0, m.Len())
require.True(t, ok)
require.Equal(t, 42, value)
require.Equal(t, 0, m.Len())
value, ok = m.Take("key1")
assert.False(t, ok)
assert.Equal(t, 0, value)
assert.Equal(t, 0, m.Len())
require.False(t, ok)
require.Equal(t, 0, value)
require.Equal(t, 0, m.Len())
}
func TestMap_Seq2(t *testing.T) {
@@ -186,10 +186,10 @@ func TestMap_Seq2(t *testing.T) {
collected := maps.Collect(m.Seq2())
assert.Equal(t, 3, len(collected))
assert.Equal(t, 1, collected["key1"])
assert.Equal(t, 2, collected["key2"])
assert.Equal(t, 3, collected["key3"])
require.Equal(t, 3, len(collected))
require.Equal(t, 1, collected["key1"])
require.Equal(t, 2, collected["key2"])
require.Equal(t, 3, collected["key3"])
}
func TestMap_Seq2_EarlyReturn(t *testing.T) {
@@ -208,7 +208,7 @@ func TestMap_Seq2_EarlyReturn(t *testing.T) {
}
}
assert.Equal(t, 2, count)
require.Equal(t, 2, count)
}
func TestMap_Seq2_EmptyMap(t *testing.T) {
@@ -221,7 +221,7 @@ func TestMap_Seq2_EmptyMap(t *testing.T) {
count++
}
assert.Equal(t, 0, count)
require.Equal(t, 0, count)
}
func TestMap_Seq(t *testing.T) {
@@ -237,10 +237,10 @@ func TestMap_Seq(t *testing.T) {
collected = append(collected, v)
}
assert.Equal(t, 3, len(collected))
assert.Contains(t, collected, 1)
assert.Contains(t, collected, 2)
assert.Contains(t, collected, 3)
require.Equal(t, 3, len(collected))
require.Contains(t, collected, 1)
require.Contains(t, collected, 2)
require.Contains(t, collected, 3)
}
func TestMap_Seq_EarlyReturn(t *testing.T) {
@@ -259,7 +259,7 @@ func TestMap_Seq_EarlyReturn(t *testing.T) {
}
}
assert.Equal(t, 2, count)
require.Equal(t, 2, count)
}
func TestMap_Seq_EmptyMap(t *testing.T) {
@@ -272,7 +272,7 @@ func TestMap_Seq_EmptyMap(t *testing.T) {
count++
}
assert.Equal(t, 0, count)
require.Equal(t, 0, count)
}
func TestMap_MarshalJSON(t *testing.T) {
@@ -283,16 +283,16 @@ func TestMap_MarshalJSON(t *testing.T) {
m.Set("key2", 2)
data, err := json.Marshal(m)
assert.NoError(t, err)
require.NoError(t, err)
result := &Map[string, int]{}
err = json.Unmarshal(data, result)
assert.NoError(t, err)
assert.Equal(t, 2, result.Len())
require.NoError(t, err)
require.Equal(t, 2, result.Len())
v1, _ := result.Get("key1")
v2, _ := result.Get("key2")
assert.Equal(t, 1, v1)
assert.Equal(t, 2, v2)
require.Equal(t, 1, v1)
require.Equal(t, 2, v2)
}
func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
@@ -301,8 +301,8 @@ func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
m := NewMap[string, int]()
data, err := json.Marshal(m)
assert.NoError(t, err)
assert.Equal(t, "{}", string(data))
require.NoError(t, err)
require.Equal(t, "{}", string(data))
}
func TestMap_UnmarshalJSON(t *testing.T) {
@@ -312,16 +312,16 @@ func TestMap_UnmarshalJSON(t *testing.T) {
m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, 2, m.Len())
require.Equal(t, 2, m.Len())
value, ok := m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 1, value)
require.True(t, ok)
require.Equal(t, 1, value)
value, ok = m.Get("key2")
assert.True(t, ok)
assert.Equal(t, 2, value)
require.True(t, ok)
require.Equal(t, 2, value)
}
func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
@@ -331,8 +331,8 @@ func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err)
assert.Equal(t, 0, m.Len())
require.NoError(t, err)
require.Equal(t, 0, m.Len())
}
func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
@@ -342,7 +342,7 @@ func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m)
assert.Error(t, err)
require.Error(t, err)
}
func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
@@ -353,15 +353,15 @@ func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
jsonData := `{"key1": 1, "key2": 2}`
err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, 2, m.Len())
require.Equal(t, 2, m.Len())
_, ok := m.Get("existing")
assert.False(t, ok)
require.False(t, ok)
value, ok := m.Get("key1")
assert.True(t, ok)
assert.Equal(t, 1, value)
require.True(t, ok)
require.Equal(t, 1, value)
}
func TestMap_JSONRoundTrip(t *testing.T) {
@@ -373,18 +373,18 @@ func TestMap_JSONRoundTrip(t *testing.T) {
original.Set("key3", 3)
data, err := json.Marshal(original)
assert.NoError(t, err)
require.NoError(t, err)
restored := NewMap[string, int]()
err = json.Unmarshal(data, restored)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, original.Len(), restored.Len())
require.Equal(t, original.Len(), restored.Len())
for k, v := range original.Seq2() {
restoredValue, ok := restored.Get(k)
assert.True(t, ok)
assert.Equal(t, v, restoredValue)
require.True(t, ok)
require.Equal(t, v, restoredValue)
}
}
@@ -405,15 +405,15 @@ func TestMap_ConcurrentAccess(t *testing.T) {
key := id*numOperations + j
m.Set(key, key*2)
value, ok := m.Get(key)
assert.True(t, ok)
assert.Equal(t, key*2, value)
require.True(t, ok)
require.Equal(t, key*2, value)
}
}(i)
}
wg.Wait()
assert.Equal(t, numGoroutines*numOperations, m.Len())
require.Equal(t, numGoroutines*numOperations, m.Len())
}
func TestMap_ConcurrentReadWrite(t *testing.T) {
@@ -438,7 +438,7 @@ func TestMap_ConcurrentReadWrite(t *testing.T) {
key := j % 1000
value, ok := m.Get(key)
if ok {
assert.Equal(t, key, value)
require.Equal(t, key, value)
}
_ = m.Len()
}
@@ -478,10 +478,10 @@ func TestMap_ConcurrentSeq2(t *testing.T) {
defer wg.Done()
count := 0
for k, v := range m.Seq2() {
assert.Equal(t, k*2, v)
require.Equal(t, k*2, v)
count++
}
assert.Equal(t, 100, count)
require.Equal(t, 100, count)
}()
}
@@ -509,9 +509,9 @@ func TestMap_ConcurrentSeq(t *testing.T) {
values[v] = true
count++
}
assert.Equal(t, 100, count)
require.Equal(t, 100, count)
for i := range 100 {
assert.True(t, values[i*2])
require.True(t, values[i*2])
}
}()
}
@@ -548,19 +548,19 @@ func TestMap_ConcurrentTake(t *testing.T) {
wg.Wait()
assert.Equal(t, 0, m.Len())
require.Equal(t, 0, m.Len())
allTaken := make(map[int]bool)
for _, workerTaken := range taken {
for _, value := range workerTaken {
assert.False(t, allTaken[value], "Value %d was taken multiple times", value)
require.False(t, allTaken[value], "Value %d was taken multiple times", value)
allTaken[value] = true
}
}
assert.Equal(t, numItems, len(allTaken))
require.Equal(t, numItems, len(allTaken))
for i := range numItems {
assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
}
}
@@ -570,20 +570,20 @@ func TestMap_TypeSafety(t *testing.T) {
stringIntMap := NewMap[string, int]()
stringIntMap.Set("key", 42)
value, ok := stringIntMap.Get("key")
assert.True(t, ok)
assert.Equal(t, 42, value)
require.True(t, ok)
require.Equal(t, 42, value)
intStringMap := NewMap[int, string]()
intStringMap.Set(42, "value")
strValue, ok := intStringMap.Get(42)
assert.True(t, ok)
assert.Equal(t, "value", strValue)
require.True(t, ok)
require.Equal(t, "value", strValue)
structMap := NewMap[string, struct{ Name string }]()
structMap.Set("key", struct{ Name string }{Name: "test"})
structValue, ok := structMap.Get("key")
assert.True(t, ok)
assert.Equal(t, "test", structValue.Name)
require.True(t, ok)
require.Equal(t, "test", structValue.Name)
}
func TestMap_InterfaceCompliance(t *testing.T) {

View File

@@ -112,15 +112,6 @@ func (s *Slice[T]) Len() int {
return len(s.inner)
}
// Slice returns a copy of the underlying slice.
func (s *Slice[T]) Slice() []T {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]T, len(s.inner))
copy(result, s.inner)
return result
}
// SetSlice replaces the entire slice with a new one.
func (s *Slice[T]) SetSlice(items []T) {
s.mu.Lock()
@@ -129,13 +120,6 @@ func (s *Slice[T]) SetSlice(items []T) {
copy(s.inner, items)
}
// Clear removes all elements from the slice.
func (s *Slice[T]) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.inner = s.inner[:0]
}
// Seq returns an iterator that yields elements from the slice.
func (s *Slice[T]) Seq() iter.Seq[T] {
return func(yield func(T) bool) {

View File

@@ -1,12 +1,12 @@
package csync
import (
"slices"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -25,7 +25,7 @@ func TestLazySlice_Seq(t *testing.T) {
result = append(result, v)
}
assert.Equal(t, data, result)
require.Equal(t, data, result)
}
func TestLazySlice_SeqWaitsForLoading(t *testing.T) {
@@ -41,15 +41,15 @@ func TestLazySlice_SeqWaitsForLoading(t *testing.T) {
return data
})
assert.False(t, loaded.Load(), "should not be loaded immediately")
require.False(t, loaded.Load(), "should not be loaded immediately")
var result []string
for v := range s.Seq() {
result = append(result, v)
}
assert.True(t, loaded.Load(), "should be loaded after Seq")
assert.Equal(t, data, result)
require.True(t, loaded.Load(), "should be loaded after Seq")
require.Equal(t, data, result)
}
func TestLazySlice_EmptySlice(t *testing.T) {
@@ -64,7 +64,7 @@ func TestLazySlice_EmptySlice(t *testing.T) {
result = append(result, v)
}
assert.Empty(t, result)
require.Empty(t, result)
}
func TestLazySlice_EarlyBreak(t *testing.T) {
@@ -85,25 +85,25 @@ func TestLazySlice_EarlyBreak(t *testing.T) {
}
}
assert.Equal(t, []string{"a", "b"}, result)
require.Equal(t, []string{"a", "b"}, result)
}
func TestSlice(t *testing.T) {
t.Run("NewSlice", func(t *testing.T) {
s := NewSlice[int]()
assert.Equal(t, 0, s.Len())
require.Equal(t, 0, s.Len())
})
t.Run("NewSliceFrom", func(t *testing.T) {
original := []int{1, 2, 3}
s := NewSliceFrom(original)
assert.Equal(t, 3, s.Len())
require.Equal(t, 3, s.Len())
// Verify it's a copy, not a reference
original[0] = 999
val, ok := s.Get(0)
require.True(t, ok)
assert.Equal(t, 1, val)
require.Equal(t, 1, val)
})
t.Run("Append", func(t *testing.T) {
@@ -111,14 +111,14 @@ func TestSlice(t *testing.T) {
s.Append("hello")
s.Append("world")
assert.Equal(t, 2, s.Len())
require.Equal(t, 2, s.Len())
val, ok := s.Get(0)
require.True(t, ok)
assert.Equal(t, "hello", val)
require.Equal(t, "hello", val)
val, ok = s.Get(1)
require.True(t, ok)
assert.Equal(t, "world", val)
require.Equal(t, "world", val)
})
t.Run("Prepend", func(t *testing.T) {
@@ -126,14 +126,14 @@ func TestSlice(t *testing.T) {
s.Append("world")
s.Prepend("hello")
assert.Equal(t, 2, s.Len())
require.Equal(t, 2, s.Len())
val, ok := s.Get(0)
require.True(t, ok)
assert.Equal(t, "hello", val)
require.Equal(t, "hello", val)
val, ok = s.Get(1)
require.True(t, ok)
assert.Equal(t, "world", val)
require.Equal(t, "world", val)
})
t.Run("Delete", func(t *testing.T) {
@@ -141,22 +141,22 @@ func TestSlice(t *testing.T) {
// Delete middle element
ok := s.Delete(2)
assert.True(t, ok)
assert.Equal(t, 4, s.Len())
require.True(t, ok)
require.Equal(t, 4, s.Len())
expected := []int{1, 2, 4, 5}
actual := s.Slice()
assert.Equal(t, expected, actual)
actual := slices.Collect(s.Seq())
require.Equal(t, expected, actual)
// Delete out of bounds
ok = s.Delete(10)
assert.False(t, ok)
assert.Equal(t, 4, s.Len())
require.False(t, ok)
require.Equal(t, 4, s.Len())
// Delete negative index
ok = s.Delete(-1)
assert.False(t, ok)
assert.Equal(t, 4, s.Len())
require.False(t, ok)
require.Equal(t, 4, s.Len())
})
t.Run("Get", func(t *testing.T) {
@@ -164,34 +164,34 @@ func TestSlice(t *testing.T) {
val, ok := s.Get(1)
require.True(t, ok)
assert.Equal(t, "b", val)
require.Equal(t, "b", val)
// Out of bounds
_, ok = s.Get(10)
assert.False(t, ok)
require.False(t, ok)
// Negative index
_, ok = s.Get(-1)
assert.False(t, ok)
require.False(t, ok)
})
t.Run("Set", func(t *testing.T) {
s := NewSliceFrom([]string{"a", "b", "c"})
ok := s.Set(1, "modified")
assert.True(t, ok)
require.True(t, ok)
val, ok := s.Get(1)
require.True(t, ok)
assert.Equal(t, "modified", val)
require.Equal(t, "modified", val)
// Out of bounds
ok = s.Set(10, "invalid")
assert.False(t, ok)
require.False(t, ok)
// Negative index
ok = s.Set(-1, "invalid")
assert.False(t, ok)
require.False(t, ok)
})
t.Run("SetSlice", func(t *testing.T) {
@@ -202,36 +202,28 @@ func TestSlice(t *testing.T) {
newItems := []int{10, 20, 30}
s.SetSlice(newItems)
assert.Equal(t, 3, s.Len())
assert.Equal(t, newItems, s.Slice())
require.Equal(t, 3, s.Len())
require.Equal(t, newItems, slices.Collect(s.Seq()))
// Verify it's a copy
newItems[0] = 999
val, ok := s.Get(0)
require.True(t, ok)
assert.Equal(t, 10, val)
})
t.Run("Clear", func(t *testing.T) {
s := NewSliceFrom([]int{1, 2, 3})
assert.Equal(t, 3, s.Len())
s.Clear()
assert.Equal(t, 0, s.Len())
require.Equal(t, 10, val)
})
t.Run("Slice", func(t *testing.T) {
original := []int{1, 2, 3}
s := NewSliceFrom(original)
copy := s.Slice()
assert.Equal(t, original, copy)
copied := slices.Collect(s.Seq())
require.Equal(t, original, copied)
// Verify it's a copy
copy[0] = 999
copied[0] = 999
val, ok := s.Get(0)
require.True(t, ok)
assert.Equal(t, 1, val)
require.Equal(t, 1, val)
})
t.Run("Seq", func(t *testing.T) {
@@ -242,7 +234,7 @@ func TestSlice(t *testing.T) {
result = append(result, v)
}
assert.Equal(t, []int{1, 2, 3}, result)
require.Equal(t, []int{1, 2, 3}, result)
})
t.Run("SeqWithIndex", func(t *testing.T) {
@@ -255,8 +247,8 @@ func TestSlice(t *testing.T) {
values = append(values, v)
}
assert.Equal(t, []int{0, 1, 2}, indices)
assert.Equal(t, []string{"a", "b", "c"}, values)
require.Equal(t, []int{0, 1, 2}, indices)
require.Equal(t, []string{"a", "b", "c"}, values)
})
t.Run("ConcurrentAccess", func(t *testing.T) {
@@ -267,22 +259,17 @@ func TestSlice(t *testing.T) {
var wg sync.WaitGroup
// Concurrent appends
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
for i := range numGoroutines {
wg.Add(2)
go func(start int) {
defer wg.Done()
for j := 0; j < itemsPerGoroutine; j++ {
for j := range itemsPerGoroutine {
s.Append(start*itemsPerGoroutine + j)
}
}(i)
}
// Concurrent reads
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < itemsPerGoroutine; j++ {
for range itemsPerGoroutine {
s.Len() // Just read the length
}
}()
@@ -291,6 +278,6 @@ func TestSlice(t *testing.T) {
wg.Wait()
// Should have all items
assert.Equal(t, numGoroutines*itemsPerGoroutine, s.Len())
require.Equal(t, numGoroutines*itemsPerGoroutine, s.Len())
})
}

View File

@@ -1,26 +1,24 @@
package env
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOsEnv_Get(t *testing.T) {
env := New()
// Test getting an existing environment variable
os.Setenv("TEST_VAR", "test_value")
defer os.Unsetenv("TEST_VAR")
t.Setenv("TEST_VAR", "test_value")
value := env.Get("TEST_VAR")
assert.Equal(t, "test_value", value)
require.Equal(t, "test_value", value)
// Test getting a non-existent environment variable
value = env.Get("NON_EXISTENT_VAR")
assert.Equal(t, "", value)
require.Equal(t, "", value)
}
func TestOsEnv_Env(t *testing.T) {
@@ -29,12 +27,12 @@ func TestOsEnv_Env(t *testing.T) {
envVars := env.Env()
// Environment should not be empty in normal circumstances
assert.NotNil(t, envVars)
assert.Greater(t, len(envVars), 0)
require.NotNil(t, envVars)
require.Greater(t, len(envVars), 0)
// Each environment variable should be in key=value format
for _, envVar := range envVars {
assert.Contains(t, envVar, "=")
require.Contains(t, envVar, "=")
}
}
@@ -45,8 +43,8 @@ func TestNewFromMap(t *testing.T) {
}
env := NewFromMap(testMap)
assert.NotNil(t, env)
assert.IsType(t, &mapEnv{}, env)
require.NotNil(t, env)
require.IsType(t, &mapEnv{}, env)
}
func TestMapEnv_Get(t *testing.T) {
@@ -58,11 +56,11 @@ func TestMapEnv_Get(t *testing.T) {
env := NewFromMap(testMap)
// Test getting existing keys
assert.Equal(t, "value1", env.Get("KEY1"))
assert.Equal(t, "value2", env.Get("KEY2"))
require.Equal(t, "value1", env.Get("KEY1"))
require.Equal(t, "value2", env.Get("KEY2"))
// Test getting non-existent key
assert.Equal(t, "", env.Get("NON_EXISTENT"))
require.Equal(t, "", env.Get("NON_EXISTENT"))
}
func TestMapEnv_Env(t *testing.T) {
@@ -75,30 +73,30 @@ func TestMapEnv_Env(t *testing.T) {
env := NewFromMap(testMap)
envVars := env.Env()
assert.Len(t, envVars, 2)
require.Len(t, envVars, 2)
// Convert to map for easier testing (order is not guaranteed)
envMap := make(map[string]string)
for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2)
assert.Len(t, parts, 2)
require.Len(t, parts, 2)
envMap[parts[0]] = parts[1]
}
assert.Equal(t, "value1", envMap["KEY1"])
assert.Equal(t, "value2", envMap["KEY2"])
require.Equal(t, "value1", envMap["KEY1"])
require.Equal(t, "value2", envMap["KEY2"])
})
t.Run("empty map", func(t *testing.T) {
env := NewFromMap(map[string]string{})
envVars := env.Env()
assert.Nil(t, envVars)
require.Nil(t, envVars)
})
t.Run("nil map", func(t *testing.T) {
env := NewFromMap(nil)
envVars := env.Env()
assert.Nil(t, envVars)
require.Nil(t, envVars)
})
}
@@ -111,8 +109,8 @@ func TestMapEnv_GetEmptyValue(t *testing.T) {
env := NewFromMap(testMap)
// Test that empty values are returned correctly
assert.Equal(t, "", env.Get("EMPTY_KEY"))
assert.Equal(t, "value", env.Get("NORMAL_KEY"))
require.Equal(t, "", env.Get("EMPTY_KEY"))
require.Equal(t, "value", env.Get("NORMAL_KEY"))
}
func TestMapEnv_EnvFormat(t *testing.T) {
@@ -124,7 +122,7 @@ func TestMapEnv_EnvFormat(t *testing.T) {
env := NewFromMap(testMap)
envVars := env.Env()
assert.Len(t, envVars, 2)
require.Len(t, envVars, 2)
// Check that the format is correct even with special characters
found := make(map[string]bool)
@@ -137,6 +135,6 @@ func TestMapEnv_EnvFormat(t *testing.T) {
}
}
assert.True(t, found["equals"], "Should handle values with equals signs")
assert.True(t, found["spaces"], "Should handle values with spaces")
require.True(t, found["equals"], "Should handle values with equals signs")
require.True(t, found["spaces"], "Should handle values with spaces")
}

View File

@@ -97,8 +97,7 @@ func TestProcessContextPaths(t *testing.T) {
// Test with tilde expansion (if we can create a file in home directory)
tmpDir = t.TempDir()
rollback := setHomeEnv(tmpDir)
defer rollback()
setHomeEnv(t, tmpDir)
homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
if err == nil {
@@ -114,12 +113,11 @@ func TestProcessContextPaths(t *testing.T) {
}
}
func setHomeEnv(path string) (rollback func()) {
func setHomeEnv(tb testing.TB, path string) {
tb.Helper()
key := "HOME"
if runtime.GOOS == "windows" {
key = "USERPROFILE"
}
original := os.Getenv(key)
os.Setenv(key, path)
return func() { os.Setenv(key, original) }
tb.Setenv(key, path)
}

View File

@@ -5,13 +5,12 @@ package watcher
import "syscall"
func Ulimit() (uint64, error) {
var currentLimit uint64 = 0
var rLimit syscall.Rlimit
err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
if err != nil {
return 0, err
}
currentLimit = rLimit.Cur
currentLimit := rLimit.Cur
rLimit.Cur = rLimit.Max / 10 * 8
err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -18,9 +17,9 @@ func TestShellPerformanceComparison(t *testing.T) {
duration := time.Since(start)
require.NoError(t, err)
assert.Equal(t, 0, exitCode)
assert.Contains(t, stdout, "hello")
assert.Empty(t, stderr)
require.Equal(t, 0, exitCode)
require.Contains(t, stdout, "hello")
require.Empty(t, stderr)
t.Logf("Quick command took: %v", duration)
}

View File

@@ -317,7 +317,7 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
if s.selectedModel == nil {
return util.ReportError(fmt.Errorf("no model selected"))
return nil
}
cfg := config.Get()

View File

@@ -73,11 +73,6 @@ type completionsCmp struct {
query string // The current filter query
}
const (
maxCompletionsWidth = 80 // Maximum width for the completions popup
minCompletionsWidth = 20 // Minimum width for the completions popup
)
func New() Completions {
completionsKeyMap := DefaultKeyMap()
keyMap := list.DefaultKeyMap()