test: improve tests (#315)

This commit is contained in:
Carlos Alexandro Becker
2025-07-28 10:55:42 -03:00
committed by GitHub
parent 60e7d043fc
commit 0506272332
9 changed files with 360 additions and 365 deletions

View File

@@ -19,7 +19,8 @@
- **Interfaces**: Define interfaces in consuming packages, keep them small and focused - **Interfaces**: Define interfaces in consuming packages, keep them small and focused
- **Structs**: Use struct embedding for composition, group related fields - **Structs**: Use struct embedding for composition, group related fields
- **Constants**: Use typed constants with iota for enums, group in const blocks - **Constants**: Use typed constants with iota for enums, group in const blocks
- **Testing**: Use testify/assert and testify/require, parallel tests with `t.Parallel()` - **Testing**: Use testify's `require` package, parallel tests with `t.Parallel()`,
`t.SetEnv()` to set environment variables.
- **JSON tags**: Use snake_case for JSON field names - **JSON tags**: Use snake_case for JSON field names
- **File permissions**: Use octal notation (0o755, 0o644) for file permissions - **File permissions**: Use octal notation (0o755, 0o644) for file permissions
- **Comments**: End comments in periods unless comments are at the end of the line. - **Comments**: End comments in periods unless comments are at the end of the line.

View File

@@ -11,7 +11,7 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -28,12 +28,12 @@ func TestConfig_LoadFromReaders(t *testing.T) {
loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3}) loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, loadedConfig) require.NotNil(t, loadedConfig)
assert.Equal(t, 1, loadedConfig.Providers.Len()) require.Equal(t, 1, loadedConfig.Providers.Len())
pc, _ := loadedConfig.Providers.Get("openai") pc, _ := loadedConfig.Providers.Get("openai")
assert.Equal(t, "key2", pc.APIKey) require.Equal(t, "key2", pc.APIKey)
assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
} }
func TestConfig_setDefaults(t *testing.T) { func TestConfig_setDefaults(t *testing.T) {
@@ -41,18 +41,18 @@ func TestConfig_setDefaults(t *testing.T) {
cfg.setDefaults("/tmp") cfg.setDefaults("/tmp")
assert.NotNil(t, cfg.Options) require.NotNil(t, cfg.Options)
assert.NotNil(t, cfg.Options.TUI) require.NotNil(t, cfg.Options.TUI)
assert.NotNil(t, cfg.Options.ContextPaths) require.NotNil(t, cfg.Options.ContextPaths)
assert.NotNil(t, cfg.Providers) require.NotNil(t, cfg.Providers)
assert.NotNil(t, cfg.Models) require.NotNil(t, cfg.Models)
assert.NotNil(t, cfg.LSP) require.NotNil(t, cfg.LSP)
assert.NotNil(t, cfg.MCP) require.NotNil(t, cfg.MCP)
assert.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory) require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
for _, path := range defaultContextPaths { for _, path := range defaultContextPaths {
assert.Contains(t, cfg.Options.ContextPaths, path) require.Contains(t, cfg.Options.ContextPaths, path)
} }
assert.Equal(t, "/tmp", cfg.workingDir) require.Equal(t, "/tmp", cfg.workingDir)
} }
func TestConfig_configureProviders(t *testing.T) { func TestConfig_configureProviders(t *testing.T) {
@@ -74,12 +74,12 @@ func TestConfig_configureProviders(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len()) require.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder // We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("openai") pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "$OPENAI_API_KEY", pc.APIKey) require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
} }
func TestConfig_configureProvidersWithOverride(t *testing.T) { func TestConfig_configureProvidersWithOverride(t *testing.T) {
@@ -117,15 +117,15 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, cfg.Providers.Len()) require.Equal(t, 1, cfg.Providers.Len())
// We want to make sure that we keep the configured API key as a placeholder // We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("openai") pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "xyz", pc.APIKey) require.Equal(t, "xyz", pc.APIKey)
assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL) require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
assert.Len(t, pc.Models, 2) require.Len(t, pc.Models, 2)
assert.Equal(t, "Updated", pc.Models[0].Name) require.Equal(t, "Updated", pc.Models[0].Name)
} }
func TestConfig_configureProvidersWithNewProvider(t *testing.T) { func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
@@ -159,20 +159,20 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
// Should be to because of the env variable // Should be to because of the env variable
assert.Equal(t, cfg.Providers.Len(), 2) require.Equal(t, cfg.Providers.Len(), 2)
// We want to make sure that we keep the configured API key as a placeholder // We want to make sure that we keep the configured API key as a placeholder
pc, _ := cfg.Providers.Get("custom") pc, _ := cfg.Providers.Get("custom")
assert.Equal(t, "xyz", pc.APIKey) require.Equal(t, "xyz", pc.APIKey)
// Make sure we set the ID correctly // Make sure we set the ID correctly
assert.Equal(t, "custom", pc.ID) require.Equal(t, "custom", pc.ID)
assert.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL) require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
assert.Len(t, pc.Models, 1) require.Len(t, pc.Models, 1)
_, ok := cfg.Providers.Get("openai") _, ok := cfg.Providers.Get("openai")
assert.True(t, ok, "OpenAI provider should still be present") require.True(t, ok, "OpenAI provider should still be present")
} }
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) { func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
@@ -195,13 +195,13 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
bedrockProvider, ok := cfg.Providers.Get("bedrock") bedrockProvider, ok := cfg.Providers.Get("bedrock")
assert.True(t, ok, "Bedrock provider should be present") require.True(t, ok, "Bedrock provider should be present")
assert.Len(t, bedrockProvider.Models, 1) require.Len(t, bedrockProvider.Models, 1)
assert.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID) require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
} }
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) { func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
@@ -221,9 +221,9 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
// Provider should not be configured without credentials // Provider should not be configured without credentials
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
} }
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) { func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
@@ -246,7 +246,7 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.Error(t, err) require.Error(t, err)
} }
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) { func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
@@ -270,15 +270,15 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
vertexProvider, ok := cfg.Providers.Get("vertexai") vertexProvider, ok := cfg.Providers.Get("vertexai")
assert.True(t, ok, "VertexAI provider should be present") require.True(t, ok, "VertexAI provider should be present")
assert.Len(t, vertexProvider.Models, 1) require.Len(t, vertexProvider.Models, 1)
assert.Equal(t, "gemini-pro", vertexProvider.Models[0].ID) require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
assert.Equal(t, "test-project", vertexProvider.ExtraParams["project"]) require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
assert.Equal(t, "us-central1", vertexProvider.ExtraParams["location"]) require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
} }
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) { func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
@@ -302,9 +302,9 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
// Provider should not be configured without proper credentials // Provider should not be configured without proper credentials
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
} }
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) { func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
@@ -327,9 +327,9 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
// Provider should not be configured without project // Provider should not be configured without project
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
} }
func TestConfig_configureProvidersSetProviderID(t *testing.T) { func TestConfig_configureProvidersSetProviderID(t *testing.T) {
@@ -351,12 +351,12 @@ func TestConfig_configureProvidersSetProviderID(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
// Provider ID should be set // Provider ID should be set
pc, _ := cfg.Providers.Get("openai") pc, _ := cfg.Providers.Get("openai")
assert.Equal(t, "openai", pc.ID) require.Equal(t, "openai", pc.ID)
} }
func TestConfig_EnabledProviders(t *testing.T) { func TestConfig_EnabledProviders(t *testing.T) {
@@ -377,7 +377,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
} }
enabled := cfg.EnabledProviders() enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 2) require.Len(t, enabled, 2)
}) })
t.Run("some providers disabled", func(t *testing.T) { t.Run("some providers disabled", func(t *testing.T) {
@@ -397,8 +397,8 @@ func TestConfig_EnabledProviders(t *testing.T) {
} }
enabled := cfg.EnabledProviders() enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 1) require.Len(t, enabled, 1)
assert.Equal(t, "openai", enabled[0].ID) require.Equal(t, "openai", enabled[0].ID)
}) })
t.Run("empty providers map", func(t *testing.T) { t.Run("empty providers map", func(t *testing.T) {
@@ -407,7 +407,7 @@ func TestConfig_EnabledProviders(t *testing.T) {
} }
enabled := cfg.EnabledProviders() enabled := cfg.EnabledProviders()
assert.Len(t, enabled, 0) require.Len(t, enabled, 0)
}) })
} }
@@ -423,7 +423,7 @@ func TestConfig_IsConfigured(t *testing.T) {
}), }),
} }
assert.True(t, cfg.IsConfigured()) require.True(t, cfg.IsConfigured())
}) })
t.Run("returns false when no providers are configured", func(t *testing.T) { t.Run("returns false when no providers are configured", func(t *testing.T) {
@@ -431,7 +431,7 @@ func TestConfig_IsConfigured(t *testing.T) {
Providers: csync.NewMap[string, ProviderConfig](), Providers: csync.NewMap[string, ProviderConfig](),
} }
assert.False(t, cfg.IsConfigured()) require.False(t, cfg.IsConfigured())
}) })
t.Run("returns false when all providers are disabled", func(t *testing.T) { t.Run("returns false when all providers are disabled", func(t *testing.T) {
@@ -450,7 +450,7 @@ func TestConfig_IsConfigured(t *testing.T) {
}), }),
} }
assert.False(t, cfg.IsConfigured()) require.False(t, cfg.IsConfigured())
}) })
} }
@@ -480,12 +480,12 @@ func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
// Provider should be removed from config when disabled // Provider should be removed from config when disabled
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("openai") _, exists := cfg.Providers.Get("openai")
assert.False(t, exists) require.False(t, exists)
} }
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
@@ -508,11 +508,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
_, exists := cfg.Providers.Get("custom") _, exists := cfg.Providers.Get("custom")
assert.True(t, exists) require.True(t, exists)
}) })
t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) { t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
@@ -531,11 +531,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom") _, exists := cfg.Providers.Get("custom")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("custom provider with no models is removed", func(t *testing.T) { t.Run("custom provider with no models is removed", func(t *testing.T) {
@@ -553,11 +553,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom") _, exists := cfg.Providers.Get("custom")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("custom provider with unsupported type is removed", func(t *testing.T) { t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
@@ -578,11 +578,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom") _, exists := cfg.Providers.Get("custom")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("valid custom provider is kept and ID is set", func(t *testing.T) { t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
@@ -603,14 +603,14 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
customProvider, exists := cfg.Providers.Get("custom") customProvider, exists := cfg.Providers.Get("custom")
assert.True(t, exists) require.True(t, exists)
assert.Equal(t, "custom", customProvider.ID) require.Equal(t, "custom", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey) require.Equal(t, "test-key", customProvider.APIKey)
assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL) require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
}) })
t.Run("custom anthropic provider is supported", func(t *testing.T) { t.Run("custom anthropic provider is supported", func(t *testing.T) {
@@ -631,15 +631,15 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
customProvider, exists := cfg.Providers.Get("custom-anthropic") customProvider, exists := cfg.Providers.Get("custom-anthropic")
assert.True(t, exists) require.True(t, exists)
assert.Equal(t, "custom-anthropic", customProvider.ID) require.Equal(t, "custom-anthropic", customProvider.ID)
assert.Equal(t, "test-key", customProvider.APIKey) require.Equal(t, "test-key", customProvider.APIKey)
assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL) require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type) require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
}) })
t.Run("disabled custom provider is removed", func(t *testing.T) { t.Run("disabled custom provider is removed", func(t *testing.T) {
@@ -661,11 +661,11 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("custom") _, exists := cfg.Providers.Get("custom")
assert.False(t, exists) require.False(t, exists)
}) })
} }
@@ -696,11 +696,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("vertexai") _, exists := cfg.Providers.Get("vertexai")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) { t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
@@ -727,11 +727,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("bedrock") _, exists := cfg.Providers.Get("bedrock")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("provider removed when API key missing with existing config", func(t *testing.T) { t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
@@ -758,11 +758,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 0) require.Equal(t, cfg.Providers.Len(), 0)
_, exists := cfg.Providers.Get("openai") _, exists := cfg.Providers.Get("openai")
assert.False(t, exists) require.False(t, exists)
}) })
t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) { t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
@@ -791,11 +791,11 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
}) })
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, cfg.Providers.Len(), 1) require.Equal(t, cfg.Providers.Len(), 1)
_, exists := cfg.Providers.Get("openai") _, exists := cfg.Providers.Get("openai")
assert.True(t, exists) require.True(t, exists)
}) })
} }
@@ -825,16 +825,16 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders) large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "large-model", large.Model) require.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider) require.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens) require.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model) require.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider) require.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens) require.Equal(t, int64(500), small.MaxTokens)
}) })
t.Run("should error if no providers configured", func(t *testing.T) { t.Run("should error if no providers configured", func(t *testing.T) {
knownProviders := []catwalk.Provider{ knownProviders := []catwalk.Provider{
@@ -861,10 +861,10 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders) _, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err) require.Error(t, err)
}) })
t.Run("should error if model is missing", func(t *testing.T) { t.Run("should error if model is missing", func(t *testing.T) {
knownProviders := []catwalk.Provider{ knownProviders := []catwalk.Provider{
@@ -891,9 +891,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders) _, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err) require.Error(t, err)
}) })
t.Run("should configure the default models with a custom provider", func(t *testing.T) { t.Run("should configure the default models with a custom provider", func(t *testing.T) {
@@ -934,15 +934,15 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders) large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "model", large.Model) require.Equal(t, "model", large.Model)
assert.Equal(t, "custom", large.Provider) require.Equal(t, "custom", large.Provider)
assert.Equal(t, int64(600), large.MaxTokens) require.Equal(t, int64(600), large.MaxTokens)
assert.Equal(t, "model", small.Model) require.Equal(t, "model", small.Model)
assert.Equal(t, "custom", small.Provider) require.Equal(t, "custom", small.Provider)
assert.Equal(t, int64(600), small.MaxTokens) require.Equal(t, int64(600), small.MaxTokens)
}) })
t.Run("should fail if no model configured", func(t *testing.T) { t.Run("should fail if no model configured", func(t *testing.T) {
@@ -978,9 +978,9 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
_, _, err = cfg.defaultModelSelection(knownProviders) _, _, err = cfg.defaultModelSelection(knownProviders)
assert.Error(t, err) require.Error(t, err)
}) })
t.Run("should use the default provider first", func(t *testing.T) { t.Run("should use the default provider first", func(t *testing.T) {
knownProviders := []catwalk.Provider{ knownProviders := []catwalk.Provider{
@@ -1020,15 +1020,15 @@ func TestConfig_defaultModelSelection(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large, small, err := cfg.defaultModelSelection(knownProviders) large, small, err := cfg.defaultModelSelection(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "large-model", large.Model) require.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider) require.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens) require.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model) require.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider) require.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens) require.Equal(t, int64(500), small.MaxTokens)
}) })
} }
@@ -1068,18 +1068,18 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders) err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge] large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall] small := cfg.Models[SelectedModelTypeSmall]
assert.Equal(t, "larger-model", large.Model) require.Equal(t, "larger-model", large.Model)
assert.Equal(t, "openai", large.Provider) require.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(2000), large.MaxTokens) require.Equal(t, int64(2000), large.MaxTokens)
assert.Equal(t, "small-model", small.Model) require.Equal(t, "small-model", small.Model)
assert.Equal(t, "openai", small.Provider) require.Equal(t, "openai", small.Provider)
assert.Equal(t, int64(500), small.MaxTokens) require.Equal(t, int64(500), small.MaxTokens)
}) })
t.Run("should be possible to use multiple providers", func(t *testing.T) { t.Run("should be possible to use multiple providers", func(t *testing.T) {
knownProviders := []catwalk.Provider{ knownProviders := []catwalk.Provider{
@@ -1130,18 +1130,18 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders) err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge] large := cfg.Models[SelectedModelTypeLarge]
small := cfg.Models[SelectedModelTypeSmall] small := cfg.Models[SelectedModelTypeSmall]
assert.Equal(t, "large-model", large.Model) require.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider) require.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(1000), large.MaxTokens) require.Equal(t, int64(1000), large.MaxTokens)
assert.Equal(t, "a-small-model", small.Model) require.Equal(t, "a-small-model", small.Model)
assert.Equal(t, "anthropic", small.Provider) require.Equal(t, "anthropic", small.Provider)
assert.Equal(t, int64(300), small.MaxTokens) require.Equal(t, int64(300), small.MaxTokens)
}) })
t.Run("should override the max tokens only", func(t *testing.T) { t.Run("should override the max tokens only", func(t *testing.T) {
@@ -1175,13 +1175,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
env := env.NewFromMap(map[string]string{}) env := env.NewFromMap(map[string]string{})
resolver := NewEnvironmentVariableResolver(env) resolver := NewEnvironmentVariableResolver(env)
err := cfg.configureProviders(env, resolver, knownProviders) err := cfg.configureProviders(env, resolver, knownProviders)
assert.NoError(t, err) require.NoError(t, err)
err = cfg.configureSelectedModels(knownProviders) err = cfg.configureSelectedModels(knownProviders)
assert.NoError(t, err) require.NoError(t, err)
large := cfg.Models[SelectedModelTypeLarge] large := cfg.Models[SelectedModelTypeLarge]
assert.Equal(t, "large-model", large.Model) require.Equal(t, "large-model", large.Model)
assert.Equal(t, "openai", large.Provider) require.Equal(t, "openai", large.Provider)
assert.Equal(t, int64(100), large.MaxTokens) require.Equal(t, int64(100), large.MaxTokens)
}) })
} }

View File

@@ -7,7 +7,7 @@ import (
"testing" "testing"
"github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
type mockProviderClient struct { type mockProviderClient struct {
@@ -29,14 +29,14 @@ func TestProvider_loadProvidersNoIssues(t *testing.T) {
client := &mockProviderClient{shouldFail: false} client := &mockProviderClient{shouldFail: false}
tmpPath := t.TempDir() + "/providers.json" tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(client, tmpPath) providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, providers) require.NotNil(t, providers)
assert.Len(t, providers, 1) require.Len(t, providers, 1)
// check if file got saved // check if file got saved
fileInfo, err := os.Stat(tmpPath) fileInfo, err := os.Stat(tmpPath)
assert.NoError(t, err) require.NoError(t, err)
assert.False(t, fileInfo.IsDir(), "Expected a file, not a directory") require.False(t, fileInfo.IsDir(), "Expected a file, not a directory")
} }
func TestProvider_loadProvidersWithIssues(t *testing.T) { func TestProvider_loadProvidersWithIssues(t *testing.T) {
@@ -58,16 +58,16 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
t.Fatalf("Failed to write old providers to file: %v", err) t.Fatalf("Failed to write old providers to file: %v", err)
} }
providers, err := loadProviders(client, tmpPath) providers, err := loadProviders(client, tmpPath)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, providers) require.NotNil(t, providers)
assert.Len(t, providers, 1) require.Len(t, providers, 1)
assert.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails") require.Equal(t, "OldProvider", providers[0].Name, "Expected to keep old provider when loading fails")
} }
func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) { func TestProvider_loadProvidersWithIssuesAndNoCache(t *testing.T) {
client := &mockProviderClient{shouldFail: true} client := &mockProviderClient{shouldFail: true}
tmpPath := t.TempDir() + "/providers.json" tmpPath := t.TempDir() + "/providers.json"
providers, err := loadProviders(client, tmpPath) providers, err := loadProviders(client, tmpPath)
assert.Error(t, err) require.Error(t, err)
assert.Nil(t, providers, "Expected nil providers when loading fails and no cache exists") require.Nil(t, providers, "Expected nil providers when loading fails and no cache exists")
} }

View File

@@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/env"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
// mockShell implements the Shell interface for testing // mockShell implements the Shell interface for testing
@@ -85,10 +85,10 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value) result, err := resolver.ResolveValue(tt.value)
if tt.expectError { if tt.expectError {
assert.Error(t, err) require.Error(t, err)
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.expected, result) require.Equal(t, tt.expected, result)
} }
}) })
} }
@@ -250,10 +250,10 @@ func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value) result, err := resolver.ResolveValue(tt.value)
if tt.expectError { if tt.expectError {
assert.Error(t, err) require.Error(t, err)
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.expected, result) require.Equal(t, tt.expected, result)
} }
}) })
} }
@@ -306,10 +306,10 @@ func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) {
result, err := resolver.ResolveValue(tt.value) result, err := resolver.ResolveValue(tt.value)
if tt.expectError { if tt.expectError {
assert.Error(t, err) require.Error(t, err)
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.expected, result) require.Equal(t, tt.expected, result)
} }
}) })
} }
@@ -319,14 +319,14 @@ func TestNewShellVariableResolver(t *testing.T) {
testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
resolver := NewShellVariableResolver(testEnv) resolver := NewShellVariableResolver(testEnv)
assert.NotNil(t, resolver) require.NotNil(t, resolver)
assert.Implements(t, (*VariableResolver)(nil), resolver) require.Implements(t, (*VariableResolver)(nil), resolver)
} }
func TestNewEnvironmentVariableResolver(t *testing.T) { func TestNewEnvironmentVariableResolver(t *testing.T) {
testEnv := env.NewFromMap(map[string]string{"TEST": "value"}) testEnv := env.NewFromMap(map[string]string{"TEST": "value"})
resolver := NewEnvironmentVariableResolver(testEnv) resolver := NewEnvironmentVariableResolver(testEnv)
assert.NotNil(t, resolver) require.NotNil(t, resolver)
assert.Implements(t, (*VariableResolver)(nil), resolver) require.Implements(t, (*VariableResolver)(nil), resolver)
} }

View File

@@ -6,16 +6,16 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestNewMap(t *testing.T) { func TestNewMap(t *testing.T) {
t.Parallel() t.Parallel()
m := NewMap[string, int]() m := NewMap[string, int]()
assert.NotNil(t, m) require.NotNil(t, m)
assert.NotNil(t, m.inner) require.NotNil(t, m.inner)
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
} }
func TestNewMapFrom(t *testing.T) { func TestNewMapFrom(t *testing.T) {
@@ -27,13 +27,13 @@ func TestNewMapFrom(t *testing.T) {
} }
m := NewMapFrom(original) m := NewMapFrom(original)
assert.NotNil(t, m) require.NotNil(t, m)
assert.Equal(t, original, m.inner) require.Equal(t, original, m.inner)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
value, ok := m.Get("key1") value, ok := m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 1, value) require.Equal(t, 1, value)
} }
func TestMap_Set(t *testing.T) { func TestMap_Set(t *testing.T) {
@@ -43,15 +43,15 @@ func TestMap_Set(t *testing.T) {
m.Set("key1", 42) m.Set("key1", 42)
value, ok := m.Get("key1") value, ok := m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
m.Set("key1", 100) m.Set("key1", 100)
value, ok = m.Get("key1") value, ok = m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 100, value) require.Equal(t, 100, value)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
} }
func TestMap_Get(t *testing.T) { func TestMap_Get(t *testing.T) {
@@ -60,13 +60,13 @@ func TestMap_Get(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
value, ok := m.Get("nonexistent") value, ok := m.Get("nonexistent")
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 0, value) require.Equal(t, 0, value)
m.Set("key1", 42) m.Set("key1", 42)
value, ok = m.Get("key1") value, ok = m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
} }
func TestMap_Del(t *testing.T) { func TestMap_Del(t *testing.T) {
@@ -76,38 +76,38 @@ func TestMap_Del(t *testing.T) {
m.Set("key1", 42) m.Set("key1", 42)
m.Set("key2", 100) m.Set("key2", 100)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
m.Del("key1") m.Del("key1")
_, ok := m.Get("key1") _, ok := m.Get("key1")
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
value, ok := m.Get("key2") value, ok := m.Get("key2")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 100, value) require.Equal(t, 100, value)
m.Del("nonexistent") m.Del("nonexistent")
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
} }
func TestMap_Len(t *testing.T) { func TestMap_Len(t *testing.T) {
t.Parallel() t.Parallel()
m := NewMap[string, int]() m := NewMap[string, int]()
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
m.Set("key1", 1) m.Set("key1", 1)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
m.Set("key2", 2) m.Set("key2", 2)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
m.Del("key1") m.Del("key1")
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
m.Del("key2") m.Del("key2")
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
} }
func TestMap_Take(t *testing.T) { func TestMap_Take(t *testing.T) {
@@ -117,19 +117,19 @@ func TestMap_Take(t *testing.T) {
m.Set("key1", 42) m.Set("key1", 42)
m.Set("key2", 100) m.Set("key2", 100)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
value, ok := m.Take("key1") value, ok := m.Take("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
_, exists := m.Get("key1") _, exists := m.Get("key1")
assert.False(t, exists) require.False(t, exists)
value, ok = m.Get("key2") value, ok = m.Get("key2")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 100, value) require.Equal(t, 100, value)
} }
func TestMap_Take_NonexistentKey(t *testing.T) { func TestMap_Take_NonexistentKey(t *testing.T) {
@@ -139,13 +139,13 @@ func TestMap_Take_NonexistentKey(t *testing.T) {
m.Set("key1", 42) m.Set("key1", 42)
value, ok := m.Take("nonexistent") value, ok := m.Take("nonexistent")
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 0, value) require.Equal(t, 0, value)
assert.Equal(t, 1, m.Len()) require.Equal(t, 1, m.Len())
value, ok = m.Get("key1") value, ok = m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
} }
func TestMap_Take_EmptyMap(t *testing.T) { func TestMap_Take_EmptyMap(t *testing.T) {
@@ -154,9 +154,9 @@ func TestMap_Take_EmptyMap(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
value, ok := m.Take("key1") value, ok := m.Take("key1")
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 0, value) require.Equal(t, 0, value)
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
} }
func TestMap_Take_SameKeyTwice(t *testing.T) { func TestMap_Take_SameKeyTwice(t *testing.T) {
@@ -166,14 +166,14 @@ func TestMap_Take_SameKeyTwice(t *testing.T) {
m.Set("key1", 42) m.Set("key1", 42)
value, ok := m.Take("key1") value, ok := m.Take("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
value, ok = m.Take("key1") value, ok = m.Take("key1")
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 0, value) require.Equal(t, 0, value)
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
} }
func TestMap_Seq2(t *testing.T) { func TestMap_Seq2(t *testing.T) {
@@ -186,10 +186,10 @@ func TestMap_Seq2(t *testing.T) {
collected := maps.Collect(m.Seq2()) collected := maps.Collect(m.Seq2())
assert.Equal(t, 3, len(collected)) require.Equal(t, 3, len(collected))
assert.Equal(t, 1, collected["key1"]) require.Equal(t, 1, collected["key1"])
assert.Equal(t, 2, collected["key2"]) require.Equal(t, 2, collected["key2"])
assert.Equal(t, 3, collected["key3"]) require.Equal(t, 3, collected["key3"])
} }
func TestMap_Seq2_EarlyReturn(t *testing.T) { func TestMap_Seq2_EarlyReturn(t *testing.T) {
@@ -208,7 +208,7 @@ func TestMap_Seq2_EarlyReturn(t *testing.T) {
} }
} }
assert.Equal(t, 2, count) require.Equal(t, 2, count)
} }
func TestMap_Seq2_EmptyMap(t *testing.T) { func TestMap_Seq2_EmptyMap(t *testing.T) {
@@ -221,7 +221,7 @@ func TestMap_Seq2_EmptyMap(t *testing.T) {
count++ count++
} }
assert.Equal(t, 0, count) require.Equal(t, 0, count)
} }
func TestMap_Seq(t *testing.T) { func TestMap_Seq(t *testing.T) {
@@ -237,10 +237,10 @@ func TestMap_Seq(t *testing.T) {
collected = append(collected, v) collected = append(collected, v)
} }
assert.Equal(t, 3, len(collected)) require.Equal(t, 3, len(collected))
assert.Contains(t, collected, 1) require.Contains(t, collected, 1)
assert.Contains(t, collected, 2) require.Contains(t, collected, 2)
assert.Contains(t, collected, 3) require.Contains(t, collected, 3)
} }
func TestMap_Seq_EarlyReturn(t *testing.T) { func TestMap_Seq_EarlyReturn(t *testing.T) {
@@ -259,7 +259,7 @@ func TestMap_Seq_EarlyReturn(t *testing.T) {
} }
} }
assert.Equal(t, 2, count) require.Equal(t, 2, count)
} }
func TestMap_Seq_EmptyMap(t *testing.T) { func TestMap_Seq_EmptyMap(t *testing.T) {
@@ -272,7 +272,7 @@ func TestMap_Seq_EmptyMap(t *testing.T) {
count++ count++
} }
assert.Equal(t, 0, count) require.Equal(t, 0, count)
} }
func TestMap_MarshalJSON(t *testing.T) { func TestMap_MarshalJSON(t *testing.T) {
@@ -283,16 +283,16 @@ func TestMap_MarshalJSON(t *testing.T) {
m.Set("key2", 2) m.Set("key2", 2)
data, err := json.Marshal(m) data, err := json.Marshal(m)
assert.NoError(t, err) require.NoError(t, err)
result := &Map[string, int]{} result := &Map[string, int]{}
err = json.Unmarshal(data, result) err = json.Unmarshal(data, result)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, result.Len()) require.Equal(t, 2, result.Len())
v1, _ := result.Get("key1") v1, _ := result.Get("key1")
v2, _ := result.Get("key2") v2, _ := result.Get("key2")
assert.Equal(t, 1, v1) require.Equal(t, 1, v1)
assert.Equal(t, 2, v2) require.Equal(t, 2, v2)
} }
func TestMap_MarshalJSON_EmptyMap(t *testing.T) { func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
@@ -301,8 +301,8 @@ func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
data, err := json.Marshal(m) data, err := json.Marshal(m)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "{}", string(data)) require.Equal(t, "{}", string(data))
} }
func TestMap_UnmarshalJSON(t *testing.T) { func TestMap_UnmarshalJSON(t *testing.T) {
@@ -312,16 +312,16 @@ func TestMap_UnmarshalJSON(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m) err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
value, ok := m.Get("key1") value, ok := m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 1, value) require.Equal(t, 1, value)
value, ok = m.Get("key2") value, ok = m.Get("key2")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 2, value) require.Equal(t, 2, value)
} }
func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) { func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
@@ -331,8 +331,8 @@ func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m) err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
} }
func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) { func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
@@ -342,7 +342,7 @@ func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
m := NewMap[string, int]() m := NewMap[string, int]()
err := json.Unmarshal([]byte(jsonData), m) err := json.Unmarshal([]byte(jsonData), m)
assert.Error(t, err) require.Error(t, err)
} }
func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) { func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
@@ -353,15 +353,15 @@ func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
jsonData := `{"key1": 1, "key2": 2}` jsonData := `{"key1": 1, "key2": 2}`
err := json.Unmarshal([]byte(jsonData), m) err := json.Unmarshal([]byte(jsonData), m)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, m.Len()) require.Equal(t, 2, m.Len())
_, ok := m.Get("existing") _, ok := m.Get("existing")
assert.False(t, ok) require.False(t, ok)
value, ok := m.Get("key1") value, ok := m.Get("key1")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 1, value) require.Equal(t, 1, value)
} }
func TestMap_JSONRoundTrip(t *testing.T) { func TestMap_JSONRoundTrip(t *testing.T) {
@@ -373,18 +373,18 @@ func TestMap_JSONRoundTrip(t *testing.T) {
original.Set("key3", 3) original.Set("key3", 3)
data, err := json.Marshal(original) data, err := json.Marshal(original)
assert.NoError(t, err) require.NoError(t, err)
restored := NewMap[string, int]() restored := NewMap[string, int]()
err = json.Unmarshal(data, restored) err = json.Unmarshal(data, restored)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, original.Len(), restored.Len()) require.Equal(t, original.Len(), restored.Len())
for k, v := range original.Seq2() { for k, v := range original.Seq2() {
restoredValue, ok := restored.Get(k) restoredValue, ok := restored.Get(k)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, v, restoredValue) require.Equal(t, v, restoredValue)
} }
} }
@@ -405,15 +405,15 @@ func TestMap_ConcurrentAccess(t *testing.T) {
key := id*numOperations + j key := id*numOperations + j
m.Set(key, key*2) m.Set(key, key*2)
value, ok := m.Get(key) value, ok := m.Get(key)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, key*2, value) require.Equal(t, key*2, value)
} }
}(i) }(i)
} }
wg.Wait() wg.Wait()
assert.Equal(t, numGoroutines*numOperations, m.Len()) require.Equal(t, numGoroutines*numOperations, m.Len())
} }
func TestMap_ConcurrentReadWrite(t *testing.T) { func TestMap_ConcurrentReadWrite(t *testing.T) {
@@ -438,7 +438,7 @@ func TestMap_ConcurrentReadWrite(t *testing.T) {
key := j % 1000 key := j % 1000
value, ok := m.Get(key) value, ok := m.Get(key)
if ok { if ok {
assert.Equal(t, key, value) require.Equal(t, key, value)
} }
_ = m.Len() _ = m.Len()
} }
@@ -478,10 +478,10 @@ func TestMap_ConcurrentSeq2(t *testing.T) {
defer wg.Done() defer wg.Done()
count := 0 count := 0
for k, v := range m.Seq2() { for k, v := range m.Seq2() {
assert.Equal(t, k*2, v) require.Equal(t, k*2, v)
count++ count++
} }
assert.Equal(t, 100, count) require.Equal(t, 100, count)
}() }()
} }
@@ -509,9 +509,9 @@ func TestMap_ConcurrentSeq(t *testing.T) {
values[v] = true values[v] = true
count++ count++
} }
assert.Equal(t, 100, count) require.Equal(t, 100, count)
for i := range 100 { for i := range 100 {
assert.True(t, values[i*2]) require.True(t, values[i*2])
} }
}() }()
} }
@@ -548,19 +548,19 @@ func TestMap_ConcurrentTake(t *testing.T) {
wg.Wait() wg.Wait()
assert.Equal(t, 0, m.Len()) require.Equal(t, 0, m.Len())
allTaken := make(map[int]bool) allTaken := make(map[int]bool)
for _, workerTaken := range taken { for _, workerTaken := range taken {
for _, value := range workerTaken { for _, value := range workerTaken {
assert.False(t, allTaken[value], "Value %d was taken multiple times", value) require.False(t, allTaken[value], "Value %d was taken multiple times", value)
allTaken[value] = true allTaken[value] = true
} }
} }
assert.Equal(t, numItems, len(allTaken)) require.Equal(t, numItems, len(allTaken))
for i := range numItems { for i := range numItems {
assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2) require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
} }
} }
@@ -570,20 +570,20 @@ func TestMap_TypeSafety(t *testing.T) {
stringIntMap := NewMap[string, int]() stringIntMap := NewMap[string, int]()
stringIntMap.Set("key", 42) stringIntMap.Set("key", 42)
value, ok := stringIntMap.Get("key") value, ok := stringIntMap.Get("key")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 42, value) require.Equal(t, 42, value)
intStringMap := NewMap[int, string]() intStringMap := NewMap[int, string]()
intStringMap.Set(42, "value") intStringMap.Set(42, "value")
strValue, ok := intStringMap.Get(42) strValue, ok := intStringMap.Get(42)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "value", strValue) require.Equal(t, "value", strValue)
structMap := NewMap[string, struct{ Name string }]() structMap := NewMap[string, struct{ Name string }]()
structMap.Set("key", struct{ Name string }{Name: "test"}) structMap.Set("key", struct{ Name string }{Name: "test"})
structValue, ok := structMap.Get("key") structValue, ok := structMap.Get("key")
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "test", structValue.Name) require.Equal(t, "test", structValue.Name)
} }
func TestMap_InterfaceCompliance(t *testing.T) { func TestMap_InterfaceCompliance(t *testing.T) {

View File

@@ -6,7 +6,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -25,7 +24,7 @@ func TestLazySlice_Seq(t *testing.T) {
result = append(result, v) result = append(result, v)
} }
assert.Equal(t, data, result) require.Equal(t, data, result)
} }
func TestLazySlice_SeqWaitsForLoading(t *testing.T) { func TestLazySlice_SeqWaitsForLoading(t *testing.T) {
@@ -41,15 +40,15 @@ func TestLazySlice_SeqWaitsForLoading(t *testing.T) {
return data return data
}) })
assert.False(t, loaded.Load(), "should not be loaded immediately") require.False(t, loaded.Load(), "should not be loaded immediately")
var result []string var result []string
for v := range s.Seq() { for v := range s.Seq() {
result = append(result, v) result = append(result, v)
} }
assert.True(t, loaded.Load(), "should be loaded after Seq") require.True(t, loaded.Load(), "should be loaded after Seq")
assert.Equal(t, data, result) require.Equal(t, data, result)
} }
func TestLazySlice_EmptySlice(t *testing.T) { func TestLazySlice_EmptySlice(t *testing.T) {
@@ -64,7 +63,7 @@ func TestLazySlice_EmptySlice(t *testing.T) {
result = append(result, v) result = append(result, v)
} }
assert.Empty(t, result) require.Empty(t, result)
} }
func TestLazySlice_EarlyBreak(t *testing.T) { func TestLazySlice_EarlyBreak(t *testing.T) {
@@ -85,25 +84,25 @@ func TestLazySlice_EarlyBreak(t *testing.T) {
} }
} }
assert.Equal(t, []string{"a", "b"}, result) require.Equal(t, []string{"a", "b"}, result)
} }
func TestSlice(t *testing.T) { func TestSlice(t *testing.T) {
t.Run("NewSlice", func(t *testing.T) { t.Run("NewSlice", func(t *testing.T) {
s := NewSlice[int]() s := NewSlice[int]()
assert.Equal(t, 0, s.Len()) require.Equal(t, 0, s.Len())
}) })
t.Run("NewSliceFrom", func(t *testing.T) { t.Run("NewSliceFrom", func(t *testing.T) {
original := []int{1, 2, 3} original := []int{1, 2, 3}
s := NewSliceFrom(original) s := NewSliceFrom(original)
assert.Equal(t, 3, s.Len()) require.Equal(t, 3, s.Len())
// Verify it's a copy, not a reference // Verify it's a copy, not a reference
original[0] = 999 original[0] = 999
val, ok := s.Get(0) val, ok := s.Get(0)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, 1, val) require.Equal(t, 1, val)
}) })
t.Run("Append", func(t *testing.T) { t.Run("Append", func(t *testing.T) {
@@ -111,14 +110,14 @@ func TestSlice(t *testing.T) {
s.Append("hello") s.Append("hello")
s.Append("world") s.Append("world")
assert.Equal(t, 2, s.Len()) require.Equal(t, 2, s.Len())
val, ok := s.Get(0) val, ok := s.Get(0)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "hello", val) require.Equal(t, "hello", val)
val, ok = s.Get(1) val, ok = s.Get(1)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "world", val) require.Equal(t, "world", val)
}) })
t.Run("Prepend", func(t *testing.T) { t.Run("Prepend", func(t *testing.T) {
@@ -126,14 +125,14 @@ func TestSlice(t *testing.T) {
s.Append("world") s.Append("world")
s.Prepend("hello") s.Prepend("hello")
assert.Equal(t, 2, s.Len()) require.Equal(t, 2, s.Len())
val, ok := s.Get(0) val, ok := s.Get(0)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "hello", val) require.Equal(t, "hello", val)
val, ok = s.Get(1) val, ok = s.Get(1)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "world", val) require.Equal(t, "world", val)
}) })
t.Run("Delete", func(t *testing.T) { t.Run("Delete", func(t *testing.T) {
@@ -141,22 +140,22 @@ func TestSlice(t *testing.T) {
// Delete middle element // Delete middle element
ok := s.Delete(2) ok := s.Delete(2)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, 4, s.Len()) require.Equal(t, 4, s.Len())
expected := []int{1, 2, 4, 5} expected := []int{1, 2, 4, 5}
actual := s.Slice() actual := s.Slice()
assert.Equal(t, expected, actual) require.Equal(t, expected, actual)
// Delete out of bounds // Delete out of bounds
ok = s.Delete(10) ok = s.Delete(10)
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 4, s.Len()) require.Equal(t, 4, s.Len())
// Delete negative index // Delete negative index
ok = s.Delete(-1) ok = s.Delete(-1)
assert.False(t, ok) require.False(t, ok)
assert.Equal(t, 4, s.Len()) require.Equal(t, 4, s.Len())
}) })
t.Run("Get", func(t *testing.T) { t.Run("Get", func(t *testing.T) {
@@ -164,34 +163,34 @@ func TestSlice(t *testing.T) {
val, ok := s.Get(1) val, ok := s.Get(1)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "b", val) require.Equal(t, "b", val)
// Out of bounds // Out of bounds
_, ok = s.Get(10) _, ok = s.Get(10)
assert.False(t, ok) require.False(t, ok)
// Negative index // Negative index
_, ok = s.Get(-1) _, ok = s.Get(-1)
assert.False(t, ok) require.False(t, ok)
}) })
t.Run("Set", func(t *testing.T) { t.Run("Set", func(t *testing.T) {
s := NewSliceFrom([]string{"a", "b", "c"}) s := NewSliceFrom([]string{"a", "b", "c"})
ok := s.Set(1, "modified") ok := s.Set(1, "modified")
assert.True(t, ok) require.True(t, ok)
val, ok := s.Get(1) val, ok := s.Get(1)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "modified", val) require.Equal(t, "modified", val)
// Out of bounds // Out of bounds
ok = s.Set(10, "invalid") ok = s.Set(10, "invalid")
assert.False(t, ok) require.False(t, ok)
// Negative index // Negative index
ok = s.Set(-1, "invalid") ok = s.Set(-1, "invalid")
assert.False(t, ok) require.False(t, ok)
}) })
t.Run("SetSlice", func(t *testing.T) { t.Run("SetSlice", func(t *testing.T) {
@@ -202,22 +201,22 @@ func TestSlice(t *testing.T) {
newItems := []int{10, 20, 30} newItems := []int{10, 20, 30}
s.SetSlice(newItems) s.SetSlice(newItems)
assert.Equal(t, 3, s.Len()) require.Equal(t, 3, s.Len())
assert.Equal(t, newItems, s.Slice()) require.Equal(t, newItems, s.Slice())
// Verify it's a copy // Verify it's a copy
newItems[0] = 999 newItems[0] = 999
val, ok := s.Get(0) val, ok := s.Get(0)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, 10, val) require.Equal(t, 10, val)
}) })
t.Run("Clear", func(t *testing.T) { t.Run("Clear", func(t *testing.T) {
s := NewSliceFrom([]int{1, 2, 3}) s := NewSliceFrom([]int{1, 2, 3})
assert.Equal(t, 3, s.Len()) require.Equal(t, 3, s.Len())
s.Clear() s.Clear()
assert.Equal(t, 0, s.Len()) require.Equal(t, 0, s.Len())
}) })
t.Run("Slice", func(t *testing.T) { t.Run("Slice", func(t *testing.T) {
@@ -225,13 +224,13 @@ func TestSlice(t *testing.T) {
s := NewSliceFrom(original) s := NewSliceFrom(original)
copy := s.Slice() copy := s.Slice()
assert.Equal(t, original, copy) require.Equal(t, original, copy)
// Verify it's a copy // Verify it's a copy
copy[0] = 999 copy[0] = 999
val, ok := s.Get(0) val, ok := s.Get(0)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, 1, val) require.Equal(t, 1, val)
}) })
t.Run("Seq", func(t *testing.T) { t.Run("Seq", func(t *testing.T) {
@@ -242,7 +241,7 @@ func TestSlice(t *testing.T) {
result = append(result, v) result = append(result, v)
} }
assert.Equal(t, []int{1, 2, 3}, result) require.Equal(t, []int{1, 2, 3}, result)
}) })
t.Run("SeqWithIndex", func(t *testing.T) { t.Run("SeqWithIndex", func(t *testing.T) {
@@ -255,8 +254,8 @@ func TestSlice(t *testing.T) {
values = append(values, v) values = append(values, v)
} }
assert.Equal(t, []int{0, 1, 2}, indices) require.Equal(t, []int{0, 1, 2}, indices)
assert.Equal(t, []string{"a", "b", "c"}, values) require.Equal(t, []string{"a", "b", "c"}, values)
}) })
t.Run("ConcurrentAccess", func(t *testing.T) { t.Run("ConcurrentAccess", func(t *testing.T) {
@@ -291,6 +290,6 @@ func TestSlice(t *testing.T) {
wg.Wait() wg.Wait()
// Should have all items // Should have all items
assert.Equal(t, numGoroutines*itemsPerGoroutine, s.Len()) require.Equal(t, numGoroutines*itemsPerGoroutine, s.Len())
}) })
} }

View File

@@ -1,26 +1,24 @@
package env package env
import ( import (
"os"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func TestOsEnv_Get(t *testing.T) { func TestOsEnv_Get(t *testing.T) {
env := New() env := New()
// Test getting an existing environment variable // Test getting an existing environment variable
os.Setenv("TEST_VAR", "test_value") t.Setenv("TEST_VAR", "test_value")
defer os.Unsetenv("TEST_VAR")
value := env.Get("TEST_VAR") value := env.Get("TEST_VAR")
assert.Equal(t, "test_value", value) require.Equal(t, "test_value", value)
// Test getting a non-existent environment variable // Test getting a non-existent environment variable
value = env.Get("NON_EXISTENT_VAR") value = env.Get("NON_EXISTENT_VAR")
assert.Equal(t, "", value) require.Equal(t, "", value)
} }
func TestOsEnv_Env(t *testing.T) { func TestOsEnv_Env(t *testing.T) {
@@ -29,12 +27,12 @@ func TestOsEnv_Env(t *testing.T) {
envVars := env.Env() envVars := env.Env()
// Environment should not be empty in normal circumstances // Environment should not be empty in normal circumstances
assert.NotNil(t, envVars) require.NotNil(t, envVars)
assert.Greater(t, len(envVars), 0) require.Greater(t, len(envVars), 0)
// Each environment variable should be in key=value format // Each environment variable should be in key=value format
for _, envVar := range envVars { for _, envVar := range envVars {
assert.Contains(t, envVar, "=") require.Contains(t, envVar, "=")
} }
} }
@@ -45,8 +43,8 @@ func TestNewFromMap(t *testing.T) {
} }
env := NewFromMap(testMap) env := NewFromMap(testMap)
assert.NotNil(t, env) require.NotNil(t, env)
assert.IsType(t, &mapEnv{}, env) require.IsType(t, &mapEnv{}, env)
} }
func TestMapEnv_Get(t *testing.T) { func TestMapEnv_Get(t *testing.T) {
@@ -58,11 +56,11 @@ func TestMapEnv_Get(t *testing.T) {
env := NewFromMap(testMap) env := NewFromMap(testMap)
// Test getting existing keys // Test getting existing keys
assert.Equal(t, "value1", env.Get("KEY1")) require.Equal(t, "value1", env.Get("KEY1"))
assert.Equal(t, "value2", env.Get("KEY2")) require.Equal(t, "value2", env.Get("KEY2"))
// Test getting non-existent key // Test getting non-existent key
assert.Equal(t, "", env.Get("NON_EXISTENT")) require.Equal(t, "", env.Get("NON_EXISTENT"))
} }
func TestMapEnv_Env(t *testing.T) { func TestMapEnv_Env(t *testing.T) {
@@ -75,30 +73,30 @@ func TestMapEnv_Env(t *testing.T) {
env := NewFromMap(testMap) env := NewFromMap(testMap)
envVars := env.Env() envVars := env.Env()
assert.Len(t, envVars, 2) require.Len(t, envVars, 2)
// Convert to map for easier testing (order is not guaranteed) // Convert to map for easier testing (order is not guaranteed)
envMap := make(map[string]string) envMap := make(map[string]string)
for _, envVar := range envVars { for _, envVar := range envVars {
parts := strings.SplitN(envVar, "=", 2) parts := strings.SplitN(envVar, "=", 2)
assert.Len(t, parts, 2) require.Len(t, parts, 2)
envMap[parts[0]] = parts[1] envMap[parts[0]] = parts[1]
} }
assert.Equal(t, "value1", envMap["KEY1"]) require.Equal(t, "value1", envMap["KEY1"])
assert.Equal(t, "value2", envMap["KEY2"]) require.Equal(t, "value2", envMap["KEY2"])
}) })
t.Run("empty map", func(t *testing.T) { t.Run("empty map", func(t *testing.T) {
env := NewFromMap(map[string]string{}) env := NewFromMap(map[string]string{})
envVars := env.Env() envVars := env.Env()
assert.Nil(t, envVars) require.Nil(t, envVars)
}) })
t.Run("nil map", func(t *testing.T) { t.Run("nil map", func(t *testing.T) {
env := NewFromMap(nil) env := NewFromMap(nil)
envVars := env.Env() envVars := env.Env()
assert.Nil(t, envVars) require.Nil(t, envVars)
}) })
} }
@@ -111,8 +109,8 @@ func TestMapEnv_GetEmptyValue(t *testing.T) {
env := NewFromMap(testMap) env := NewFromMap(testMap)
// Test that empty values are returned correctly // Test that empty values are returned correctly
assert.Equal(t, "", env.Get("EMPTY_KEY")) require.Equal(t, "", env.Get("EMPTY_KEY"))
assert.Equal(t, "value", env.Get("NORMAL_KEY")) require.Equal(t, "value", env.Get("NORMAL_KEY"))
} }
func TestMapEnv_EnvFormat(t *testing.T) { func TestMapEnv_EnvFormat(t *testing.T) {
@@ -124,7 +122,7 @@ func TestMapEnv_EnvFormat(t *testing.T) {
env := NewFromMap(testMap) env := NewFromMap(testMap)
envVars := env.Env() envVars := env.Env()
assert.Len(t, envVars, 2) require.Len(t, envVars, 2)
// Check that the format is correct even with special characters // Check that the format is correct even with special characters
found := make(map[string]bool) found := make(map[string]bool)
@@ -137,6 +135,6 @@ func TestMapEnv_EnvFormat(t *testing.T) {
} }
} }
assert.True(t, found["equals"], "Should handle values with equals signs") require.True(t, found["equals"], "Should handle values with equals signs")
assert.True(t, found["spaces"], "Should handle values with spaces") require.True(t, found["spaces"], "Should handle values with spaces")
} }

View File

@@ -97,8 +97,7 @@ func TestProcessContextPaths(t *testing.T) {
// Test with tilde expansion (if we can create a file in home directory) // Test with tilde expansion (if we can create a file in home directory)
tmpDir = t.TempDir() tmpDir = t.TempDir()
rollback := setHomeEnv(tmpDir) setHomeEnv(t, tmpDir)
defer rollback()
homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt") homeTestFile := filepath.Join(tmpDir, "crush_test_file.txt")
err = os.WriteFile(homeTestFile, []byte(testContent), 0o644) err = os.WriteFile(homeTestFile, []byte(testContent), 0o644)
if err == nil { if err == nil {
@@ -114,12 +113,11 @@ func TestProcessContextPaths(t *testing.T) {
} }
} }
func setHomeEnv(path string) (rollback func()) { func setHomeEnv(tb testing.TB, path string) {
tb.Helper()
key := "HOME" key := "HOME"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
key = "USERPROFILE" key = "USERPROFILE"
} }
original := os.Getenv(key) tb.Setenv(key, path)
os.Setenv(key, path)
return func() { os.Setenv(key, original) }
} }

View File

@@ -4,7 +4,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -18,9 +17,9 @@ func TestShellPerformanceComparison(t *testing.T) {
duration := time.Since(start) duration := time.Since(start)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, exitCode) require.Equal(t, 0, exitCode)
assert.Contains(t, stdout, "hello") require.Contains(t, stdout, "hello")
assert.Empty(t, stderr) require.Empty(t, stderr)
t.Logf("Quick command took: %v", duration) t.Logf("Quick command took: %v", duration)
} }