mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
fix: improvements
This commit is contained in:
@@ -77,38 +77,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to configure providers: %w", err)
|
||||
}
|
||||
|
||||
// Test provider connections in parallel
|
||||
var testResults sync.Map
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
slog.Info("Testing provider connections")
|
||||
defer slog.Info("Provider connection tests completed")
|
||||
for _, p := range cfg.Providers.Seq2() {
|
||||
if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
|
||||
wg.Add(1)
|
||||
go func(provider ProviderConfig) {
|
||||
defer wg.Done()
|
||||
err := provider.TestConnection(cfg.resolver)
|
||||
testResults.Store(provider.ID, err == nil)
|
||||
if err != nil {
|
||||
slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Remove failed providers
|
||||
testResults.Range(func(key, value any) bool {
|
||||
providerID := key.(string)
|
||||
passed := value.(bool)
|
||||
if !passed {
|
||||
cfg.Providers.Del(providerID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}()
|
||||
go cfg.removeUnresponsiveProviders()
|
||||
|
||||
if !cfg.IsConfigured() {
|
||||
slog.Warn("No providers configured")
|
||||
@@ -122,6 +91,38 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) removeUnresponsiveProviders() {
|
||||
// Test provider connections in parallel
|
||||
var testResults sync.Map
|
||||
var wg sync.WaitGroup
|
||||
slog.Info("Testing provider connections")
|
||||
defer slog.Info("Provider connection tests completed")
|
||||
for _, p := range c.Providers.Seq2() {
|
||||
if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
|
||||
wg.Add(1)
|
||||
go func(provider ProviderConfig) {
|
||||
defer wg.Done()
|
||||
err := provider.TestConnection(c.resolver)
|
||||
testResults.Store(provider.ID, err == nil)
|
||||
if err != nil {
|
||||
slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Remove failed providers
|
||||
testResults.Range(func(key, value any) bool {
|
||||
providerID := key.(string)
|
||||
passed := value.(bool)
|
||||
if !passed {
|
||||
c.Providers.Del(providerID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
|
||||
knownProviderNames := make(map[string]bool)
|
||||
for _, p := range knownProviders {
|
||||
|
||||
Reference in New Issue
Block a user