mirror of https://github.com/charmbracelet/crush.git
Merge remote-tracking branch 'origin/main' into list
@@ -12,6 +12,7 @@ import (

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/history"

@@ -37,8 +38,7 @@ type App struct {

	clientsMutex sync.RWMutex

	watcherCancelFuncs []context.CancelFunc
	cancelFuncsMutex sync.Mutex
	watcherCancelFuncs *csync.Slice[context.CancelFunc]
	lspWatcherWG sync.WaitGroup

	config *config.Config

@@ -76,6 +76,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {

		config: cfg,

		watcherCancelFuncs: csync.NewSlice[context.CancelFunc](),

		events: make(chan tea.Msg, 100),
		serviceEventsWG: &sync.WaitGroup{},
		tuiWG: &sync.WaitGroup{},

@@ -305,11 +307,9 @@ func (app *App) Shutdown() {
		app.CoderAgent.CancelAll()
	}

	app.cancelFuncsMutex.Lock()
	for _, cancel := range app.watcherCancelFuncs {
	for cancel := range app.watcherCancelFuncs.Seq() {
		cancel()
	}
	app.cancelFuncsMutex.Unlock()

	// Wait for all LSP watchers to finish.
	app.lspWatcherWG.Wait()
@@ -63,9 +63,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman
	workspaceWatcher := watcher.NewWorkspaceWatcher(name, lspClient)

	// Store the cancel function to be called during cleanup.
	app.cancelFuncsMutex.Lock()
	app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc)
	app.cancelFuncsMutex.Unlock()
	app.watcherCancelFuncs.Append(cancelFunc)

	// Add to map with mutex protection before starting goroutine
	app.clientsMutex.Lock()

@@ -270,7 +270,7 @@ func (c *Config) WorkingDir() string {

func (c *Config) EnabledProviders() []ProviderConfig {
	var enabled []ProviderConfig
	for _, p := range c.Providers.Seq2() {
	for p := range c.Providers.Seq() {
		if !p.Disable {
			enabled = append(enabled, p)
		}
@@ -56,6 +56,15 @@ func (m *Map[K, V]) Len() int {
	return len(m.inner)
}

// Take gets an item and then deletes it.
func (m *Map[K, V]) Take(key K) (V, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	v, ok := m.inner[key]
	delete(m.inner, key)
	return v, ok
}

// Seq2 returns an iter.Seq2 that yields key-value pairs from the map.
func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
	dst := make(map[K]V)

@@ -71,6 +80,17 @@ func (m *Map[K, V]) Seq2() iter.Seq2[K, V] {
	}
}

// Seq returns an iter.Seq that yields values from the map.
func (m *Map[K, V]) Seq() iter.Seq[V] {
	return func(yield func(V) bool) {
		for _, v := range m.Seq2() {
			if !yield(v) {
				return
			}
		}
	}
}

var (
	_ json.Unmarshaler = &Map[string, any]{}
	_ json.Marshaler = &Map[string, any]{}
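To make the new Map API above concrete, here is a minimal usage sketch based only on the methods visible in this diff (NewMap, Set, Take, Seq2, Seq). Since csync is an internal package, code like this would have to live inside the crush module; keys and values are illustrative only.

	package main

	import (
		"fmt"

		"github.com/charmbracelet/crush/internal/csync"
	)

	func main() {
		m := csync.NewMap[string, int]()
		m.Set("a", 1)
		m.Set("b", 2)

		// Take reads and deletes under one lock, so two goroutines can
		// never both claim the same key.
		if v, ok := m.Take("a"); ok {
			fmt.Println("took", v)
		}

		// Seq2 yields key/value pairs and Seq yields values; both range
		// over a snapshot (Seq2 copies into dst above), so mutating the
		// map while iterating is safe.
		for k, v := range m.Seq2() {
			fmt.Println(k, v)
		}
		for v := range m.Seq() {
			fmt.Println(v)
		}
	}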
@@ -110,6 +110,72 @@ func TestMap_Len(t *testing.T) {
	assert.Equal(t, 0, m.Len())
}

func TestMap_Take(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)
	m.Set("key2", 100)

	assert.Equal(t, 2, m.Len())

	value, ok := m.Take("key1")
	assert.True(t, ok)
	assert.Equal(t, 42, value)
	assert.Equal(t, 1, m.Len())

	_, exists := m.Get("key1")
	assert.False(t, exists)

	value, ok = m.Get("key2")
	assert.True(t, ok)
	assert.Equal(t, 100, value)
}

func TestMap_Take_NonexistentKey(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	value, ok := m.Take("nonexistent")
	assert.False(t, ok)
	assert.Equal(t, 0, value)
	assert.Equal(t, 1, m.Len())

	value, ok = m.Get("key1")
	assert.True(t, ok)
	assert.Equal(t, 42, value)
}

func TestMap_Take_EmptyMap(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	value, ok := m.Take("key1")
	assert.False(t, ok)
	assert.Equal(t, 0, value)
	assert.Equal(t, 0, m.Len())
}

func TestMap_Take_SameKeyTwice(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 42)

	value, ok := m.Take("key1")
	assert.True(t, ok)
	assert.Equal(t, 42, value)
	assert.Equal(t, 0, m.Len())

	value, ok = m.Take("key1")
	assert.False(t, ok)
	assert.Equal(t, 0, value)
	assert.Equal(t, 0, m.Len())
}

func TestMap_Seq2(t *testing.T) {
	t.Parallel()

@@ -158,6 +224,57 @@ func TestMap_Seq2_EmptyMap(t *testing.T) {
	assert.Equal(t, 0, count)
}

func TestMap_Seq(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	collected := make([]int, 0)
	for v := range m.Seq() {
		collected = append(collected, v)
	}

	assert.Equal(t, 3, len(collected))
	assert.Contains(t, collected, 1)
	assert.Contains(t, collected, 2)
	assert.Contains(t, collected, 3)
}

func TestMap_Seq_EarlyReturn(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()
	m.Set("key1", 1)
	m.Set("key2", 2)
	m.Set("key3", 3)

	count := 0
	for range m.Seq() {
		count++
		if count == 2 {
			break
		}
	}

	assert.Equal(t, 2, count)
}

func TestMap_Seq_EmptyMap(t *testing.T) {
	t.Parallel()

	m := NewMap[string, int]()

	count := 0
	for range m.Seq() {
		count++
	}

	assert.Equal(t, 0, count)
}

func TestMap_MarshalJSON(t *testing.T) {
	t.Parallel()
@@ -371,6 +488,82 @@ func TestMap_ConcurrentSeq2(t *testing.T) {
	wg.Wait()
}

func TestMap_ConcurrentSeq(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	for i := range 100 {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numIterators = 10

	wg.Add(numIterators)
	for range numIterators {
		go func() {
			defer wg.Done()
			count := 0
			values := make(map[int]bool)
			for v := range m.Seq() {
				values[v] = true
				count++
			}
			assert.Equal(t, 100, count)
			for i := range 100 {
				assert.True(t, values[i*2])
			}
		}()
	}

	wg.Wait()
}

func TestMap_ConcurrentTake(t *testing.T) {
	t.Parallel()

	m := NewMap[int, int]()
	const numItems = 1000

	for i := range numItems {
		m.Set(i, i*2)
	}

	var wg sync.WaitGroup
	const numWorkers = 10
	taken := make([][]int, numWorkers)

	wg.Add(numWorkers)
	for i := range numWorkers {
		go func(workerID int) {
			defer wg.Done()
			taken[workerID] = make([]int, 0)
			for j := workerID; j < numItems; j += numWorkers {
				if value, ok := m.Take(j); ok {
					taken[workerID] = append(taken[workerID], value)
				}
			}
		}(i)
	}

	wg.Wait()

	assert.Equal(t, 0, m.Len())

	allTaken := make(map[int]bool)
	for _, workerTaken := range taken {
		for _, value := range workerTaken {
			assert.False(t, allTaken[value], "Value %d was taken multiple times", value)
			allTaken[value] = true
		}
	}

	assert.Equal(t, numItems, len(allTaken))
	for i := range numItems {
		assert.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
	}
}

func TestMap_TypeSafety(t *testing.T) {
	t.Parallel()

@@ -431,6 +624,38 @@ func BenchmarkMap_Seq2(b *testing.B) {
	}
}

func BenchmarkMap_Seq(b *testing.B) {
	m := NewMap[int, int]()
	for i := range 1000 {
		m.Set(i, i*2)
	}

	for b.Loop() {
		for range m.Seq() {
		}
	}
}

func BenchmarkMap_Take(b *testing.B) {
	m := NewMap[int, int]()
	for i := range 1000 {
		m.Set(i, i*2)
	}

	b.ResetTimer()
	for i := 0; b.Loop(); i++ {
		key := i % 1000
		m.Take(key)
		if i%1000 == 999 {
			b.StopTimer()
			for j := range 1000 {
				m.Set(j, j*2)
			}
			b.StartTimer()
		}
	}
}

func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
	m := NewMap[int, int]()
	for i := range 1000 {
@@ -59,10 +59,10 @@ func NewSliceFrom[T any](s []T) *Slice[T] {
}

// Append adds an element to the end of the slice.
func (s *Slice[T]) Append(item T) {
func (s *Slice[T]) Append(items ...T) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.inner = append(s.inner, item)
	s.inner = append(s.inner, items...)
}

// Prepend adds an element to the beginning of the slice.

@@ -112,6 +112,15 @@ func (s *Slice[T]) Len() int {
	return len(s.inner)
}

// Slice returns a copy of the underlying slice.
func (s *Slice[T]) Slice() []T {
	s.mu.RLock()
	defer s.mu.RUnlock()
	result := make([]T, len(s.inner))
	copy(result, s.inner)
	return result
}

// SetSlice replaces the entire slice with a new one.
func (s *Slice[T]) SetSlice(items []T) {
	s.mu.Lock()
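A similar sketch for the Slice changes above, again assuming only the API shown in this diff (NewSlice, the now-variadic Append, Slice, Len, Seq) and that the code lives inside the crush module:

	package main

	import (
		"fmt"
		"slices"

		"github.com/charmbracelet/crush/internal/csync"
	)

	func main() {
		s := csync.NewSlice[string]()

		// Append now takes any number of items in a single locked call.
		s.Append("a", "b", "c")

		// Slice returns a copy, so callers can read it without racing
		// against concurrent writers.
		snapshot := s.Slice()
		fmt.Println(snapshot, s.Len())

		// Seq still supports streaming iteration, e.g. via slices.Collect.
		fmt.Println(slices.Collect(s.Seq()))
	}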
@@ -1,7 +1,6 @@
package csync

import (
	"slices"
	"sync"
	"sync/atomic"
	"testing"

@@ -146,7 +145,7 @@ func TestSlice(t *testing.T) {
	assert.Equal(t, 4, s.Len())

	expected := []int{1, 2, 4, 5}
	actual := slices.Collect(s.Seq())
	actual := s.Slice()
	assert.Equal(t, expected, actual)

	// Delete out of bounds

@@ -204,7 +203,7 @@ func TestSlice(t *testing.T) {
	s.SetSlice(newItems)

	assert.Equal(t, 3, s.Len())
	assert.Equal(t, newItems, slices.Collect(s.Seq()))
	assert.Equal(t, newItems, s.Slice())

	// Verify it's a copy
	newItems[0] = 999

@@ -225,7 +224,7 @@ func TestSlice(t *testing.T) {
	original := []int{1, 2, 3}
	s := NewSliceFrom(original)

	copy := slices.Collect(s.Seq())
	copy := s.Slice()
	assert.Equal(t, original, copy)

	// Verify it's a copy

@@ -7,7 +7,6 @@ import (
	"log/slog"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -78,7 +77,7 @@ type agent struct {
	summarizeProvider provider.Provider
	summarizeProviderID string

	activeRequests sync.Map
	activeRequests *csync.Map[string, context.CancelFunc]
}

var agentPromptMap = map[string]prompt.PromptID{

@@ -222,7 +221,7 @@ func NewAgent(
		titleProvider: titleProvider,
		summarizeProvider: summarizeProvider,
		summarizeProviderID: string(smallModelProviderCfg.ID),
		activeRequests: sync.Map{},
		activeRequests: csync.NewMap[string, context.CancelFunc](),
		tools: csync.NewLazySlice(toolFn),
	}, nil
}

@@ -233,38 +232,31 @@ func (a *agent) Model() catwalk.Model {

func (a *agent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
			slog.Info("Request cancellation initiated", "session_id", sessionID)
			cancel()
		}
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
			slog.Info("Summarize cancellation initiated", "session_id", sessionID)
			cancel()
		}
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}
}

func (a *agent) IsBusy() bool {
	busy := false
	a.activeRequests.Range(func(key, value any) bool {
		if cancelFunc, ok := value.(context.CancelFunc); ok {
			if cancelFunc != nil {
				busy = true
				return false
			}
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
		return true
	})
	}
	return busy
}

func (a *agent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Load(sessionID)
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

@@ -335,7 +327,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac

	genCtx, cancel := context.WithCancel(ctx)

	a.activeRequests.Store(sessionID, cancel)
	a.activeRequests.Set(sessionID, cancel)
	go func() {
		slog.Debug("Request started", "sessionID", sessionID)
		defer log.RecoverPanic("agent.Run", func() {

@@ -350,7 +342,7 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
			slog.Error(result.Error.Error())
		}
		slog.Debug("Request completed", "sessionID", sessionID)
		a.activeRequests.Delete(sessionID)
		a.activeRequests.Del(sessionID)
		cancel()
		a.Publish(pubsub.CreatedEvent, result)
		events <- result

@@ -682,10 +674,10 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
	summarizeCtx, cancel := context.WithCancel(ctx)

	// Store the cancel function in activeRequests to allow cancellation
	a.activeRequests.Store(sessionID+"-summarize", cancel)
	a.activeRequests.Set(sessionID+"-summarize", cancel)

	go func() {
		defer a.activeRequests.Delete(sessionID + "-summarize")
		defer a.activeRequests.Del(sessionID + "-summarize")
		defer cancel()
		event := AgentEvent{
			Type: AgentEventTypeSummarize,

@@ -850,10 +842,9 @@ func (a *agent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	a.activeRequests.Range(func(key, value any) bool {
		a.Cancel(key.(string)) // key is sessionID
		return true
	})
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {

@@ -907,7 +898,7 @@ func (a *agent) UpdateModel() error {
	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
	var smallModelProviderCfg config.ProviderConfig

	for _, p := range cfg.Providers.Seq2() {
	for p := range cfg.Providers.Seq() {
		if p.ID == smallModelCfg.Provider {
			smallModelProviderCfg = p
			break
@@ -5,9 +5,11 @@ import (
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"

	"github.com/charmbracelet/crush/internal/permission"

@@ -195,9 +197,8 @@ func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *confi
}

func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var mu sync.Mutex
	var wg sync.WaitGroup
	var result []tools.BaseTool
	result := csync.NewSlice[tools.BaseTool]()
	for name, m := range cfg.MCP {
		if m.Disabled {
			slog.Debug("skipping disabled mcp", "name", name)

@@ -218,9 +219,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
				return
			}

			mu.Lock()
			result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
			mu.Unlock()
			result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
		case config.MCPHttp:
			c, err := client.NewStreamableHttpClient(
				m.URL,

@@ -230,9 +229,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
				slog.Error("error creating mcp client", "error", err)
				return
			}
			mu.Lock()
			result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
			mu.Unlock()
			result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
		case config.MCPSse:
			c, err := client.NewSSEMCPClient(
				m.URL,

@@ -242,12 +239,10 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
				slog.Error("error creating mcp client", "error", err)
				return
			}
			mu.Lock()
			result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
			mu.Unlock()
			result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
		}
		}(name, m)
	}
	wg.Wait()
	return result
	return slices.Collect(result.Seq())
}
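The doGetMCPTools change above drops the explicit mutex because csync.Slice is safe for concurrent appends. Below is a hedged sketch of that general fan-out pattern; the function and variable names (gatherConcurrently, jobs) are hypothetical and not from the repo, and only the csync API shown in this diff is assumed.

	package main

	import (
		"fmt"
		"slices"
		"sync"

		"github.com/charmbracelet/crush/internal/csync"
	)

	// gatherConcurrently runs each job in its own goroutine and collects
	// the results through a csync.Slice instead of a mutex-guarded slice.
	func gatherConcurrently(jobs []func() string) []string {
		var wg sync.WaitGroup
		results := csync.NewSlice[string]()
		for _, job := range jobs {
			wg.Add(1)
			go func(job func() string) {
				defer wg.Done()
				results.Append(job())
			}(job)
		}
		wg.Wait()
		return slices.Collect(results.Seq())
	}

	func main() {
		out := gatherConcurrently([]func() string{
			func() string { return "a" },
			func() string { return "b" },
		})
		fmt.Println(len(out))
	}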
@@ -7,6 +7,7 @@ import (
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
)

@@ -74,8 +75,7 @@ func processContextPaths(workDir string, paths []string) string {
	)

	// Track processed files to avoid duplicates
	processedFiles := make(map[string]bool)
	var processedMutex sync.Mutex
	processedFiles := csync.NewMap[string, bool]()

	for _, path := range paths {
		wg.Add(1)

@@ -106,14 +106,8 @@ func processContextPaths(workDir string, paths []string) string {
				// Check if we've already processed this file (case-insensitive)
				lowerPath := strings.ToLower(path)

				processedMutex.Lock()
				alreadyProcessed := processedFiles[lowerPath]
				if !alreadyProcessed {
					processedFiles[lowerPath] = true
				}
				processedMutex.Unlock()

				if !alreadyProcessed {
				if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
					processedFiles.Set(lowerPath, true)
					if result := processFile(path); result != "" {
						resultCh <- result
					}

@@ -126,14 +120,8 @@ func processContextPaths(workDir string, paths []string) string {
					// Check if we've already processed this file (case-insensitive)
					lowerPath := strings.ToLower(fullPath)

					processedMutex.Lock()
					alreadyProcessed := processedFiles[lowerPath]
					if !alreadyProcessed {
						processedFiles[lowerPath] = true
					}
					processedMutex.Unlock()

					if !alreadyProcessed {
					if alreadyProcessed, _ := processedFiles.Get(lowerPath); !alreadyProcessed {
						processedFiles.Set(lowerPath, true)
						result := processFile(fullPath)
						if result != "" {
							resultCh <- result
@@ -12,6 +12,7 @@ import (

	"github.com/bmatcuk/doublestar/v4"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"

	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/lsp/protocol"

@@ -25,8 +26,7 @@ type WorkspaceWatcher struct {
	workspacePath string

	debounceTime time.Duration
	debounceMap map[string]*time.Timer
	debounceMu sync.Mutex
	debounceMap *csync.Map[string, *time.Timer]

	// File watchers registered by the server
	registrations []protocol.FileSystemWatcher

@@ -46,7 +46,7 @@ func NewWorkspaceWatcher(name string, client *lsp.Client) *WorkspaceWatcher {
		name: name,
		client: client,
		debounceTime: 300 * time.Millisecond,
		debounceMap: make(map[string]*time.Timer),
		debounceMap: csync.NewMap[string, *time.Timer](),
		registrations: []protocol.FileSystemWatcher{},
	}
}

@@ -639,26 +639,21 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt

// debounceHandleFileEvent handles file events with debouncing to reduce notifications
func (w *WorkspaceWatcher) debounceHandleFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) {
	w.debounceMu.Lock()
	defer w.debounceMu.Unlock()

	// Create a unique key based on URI and change type
	key := fmt.Sprintf("%s:%d", uri, changeType)

	// Cancel existing timer if any
	if timer, exists := w.debounceMap[key]; exists {
	if timer, exists := w.debounceMap.Get(key); exists {
		timer.Stop()
	}

	// Create new timer
	w.debounceMap[key] = time.AfterFunc(w.debounceTime, func() {
	w.debounceMap.Set(key, time.AfterFunc(w.debounceTime, func() {
		w.handleFileEvent(ctx, uri, changeType)

		// Cleanup timer after execution
		w.debounceMu.Lock()
		delete(w.debounceMap, key)
		w.debounceMu.Unlock()
	})
		w.debounceMap.Del(key)
	}))

}

// handleFileEvent sends file change notifications
@@ -6,6 +6,7 @@ import (
	"slices"
	"sync"

	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/google/uuid"
)

@@ -46,7 +47,7 @@ type permissionService struct {
	workingDir string
	sessionPermissions []PermissionRequest
	sessionPermissionsMu sync.RWMutex
	pendingRequests sync.Map
	pendingRequests *csync.Map[string, chan bool]
	autoApproveSessions []string
	autoApproveSessionsMu sync.RWMutex
	skip bool

@@ -54,9 +55,9 @@ type permissionService struct {
}

func (s *permissionService) GrantPersistent(permission PermissionRequest) {
	respCh, ok := s.pendingRequests.Load(permission.ID)
	respCh, ok := s.pendingRequests.Get(permission.ID)
	if ok {
		respCh.(chan bool) <- true
		respCh <- true
	}

	s.sessionPermissionsMu.Lock()

@@ -65,16 +66,16 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
}

func (s *permissionService) Grant(permission PermissionRequest) {
	respCh, ok := s.pendingRequests.Load(permission.ID)
	respCh, ok := s.pendingRequests.Get(permission.ID)
	if ok {
		respCh.(chan bool) <- true
		respCh <- true
	}
}

func (s *permissionService) Deny(permission PermissionRequest) {
	respCh, ok := s.pendingRequests.Load(permission.ID)
	respCh, ok := s.pendingRequests.Get(permission.ID)
	if ok {
		respCh.(chan bool) <- false
		respCh <- false
	}
}

@@ -122,8 +123,8 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {

	respCh := make(chan bool, 1)

	s.pendingRequests.Store(permission.ID, respCh)
	defer s.pendingRequests.Delete(permission.ID)
	s.pendingRequests.Set(permission.ID, respCh)
	defer s.pendingRequests.Del(permission.ID)

	s.Publish(pubsub.CreatedEvent, permission)

@@ -144,5 +145,6 @@ func NewPermissionService(workingDir string, skip bool, allowedTools []string) S
		sessionPermissions: make([]PermissionRequest, 0),
		skip: skip,
		allowedTools: allowedTools,
		pendingRequests: csync.NewMap[string, chan bool](),
	}
}
@@ -4,13 +4,14 @@ import (
	"context"
	"fmt"
	"os"
	"slices"
	"sort"
	"strings"
	"sync"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/diff"
	"github.com/charmbracelet/crush/internal/fsext"
	"github.com/charmbracelet/crush/internal/history"

@@ -71,8 +72,7 @@ type sidebarCmp struct {
	lspClients map[string]*lsp.Client
	compactMode bool
	history history.Service
	// Using a sync map here because we might receive file history events concurrently
	files sync.Map
	files *csync.Map[string, SessionFile]
}

func New(history history.Service, lspClients map[string]*lsp.Client, compact bool) Sidebar {

@@ -80,6 +80,7 @@ func New(history history.Service, lspClients map[string]*lsp.Client, compact boo
		lspClients: lspClients,
		history: history,
		compactMode: compact,
		files: csync.NewMap[string, SessionFile](),
	}
}

@@ -90,9 +91,9 @@ func (m *sidebarCmp) Init() tea.Cmd {

func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case SessionFilesMsg:
		m.files = sync.Map{}
		m.files = csync.NewMap[string, SessionFile]()
		for _, file := range msg.Files {
			m.files.Store(file.FilePath, file)
			m.files.Set(file.FilePath, file)
		}
		return m, nil

@@ -178,31 +179,30 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
	return func() tea.Msg {
		file := event.Payload
		found := false
		m.files.Range(func(key, value any) bool {
			existing := value.(SessionFile)
			if existing.FilePath == file.Path {
				if existing.History.latestVersion.Version < file.Version {
					existing.History.latestVersion = file
				} else if file.Version == 0 {
					existing.History.initialVersion = file
				} else {
					// If the version is not greater than the latest, we ignore it
					return true
				}
				before := existing.History.initialVersion.Content
				after := existing.History.latestVersion.Content
				path := existing.History.initialVersion.Path
				cwd := config.Get().WorkingDir()
				path = strings.TrimPrefix(path, cwd)
				_, additions, deletions := diff.GenerateDiff(before, after, path)
				existing.Additions = additions
				existing.Deletions = deletions
				m.files.Store(file.Path, existing)
				found = true
				return false
		for existing := range m.files.Seq() {
			if existing.FilePath != file.Path {
				continue
			}
			return true
		})
			if existing.History.latestVersion.Version < file.Version {
				existing.History.latestVersion = file
			} else if file.Version == 0 {
				existing.History.initialVersion = file
			} else {
				// If the version is not greater than the latest, we ignore it
				continue
			}
			before := existing.History.initialVersion.Content
			after := existing.History.latestVersion.Content
			path := existing.History.initialVersion.Path
			cwd := config.Get().WorkingDir()
			path = strings.TrimPrefix(path, cwd)
			_, additions, deletions := diff.GenerateDiff(before, after, path)
			existing.Additions = additions
			existing.Deletions = deletions
			m.files.Set(file.Path, existing)
			found = true
			break
		}
		if found {
			return nil
		}

@@ -215,7 +215,7 @@ func (m *sidebarCmp) handleFileHistoryEvent(event pubsub.Event[history.File]) te
			Additions: 0,
			Deletions: 0,
		}
		m.files.Store(file.Path, sf)
		m.files.Set(file.Path, sf)
		return nil
	}
}

@@ -386,12 +386,7 @@ func (m *sidebarCmp) filesBlockCompact(maxWidth int) string {

	section := t.S().Subtle.Render("Modified Files")

	files := make([]SessionFile, 0)
	m.files.Range(func(key, value any) bool {
		file := value.(SessionFile)
		files = append(files, file)
		return true
	})
	files := slices.Collect(m.files.Seq())

	if len(files) == 0 {
		content := lipgloss.JoinVertical(

@@ -620,12 +615,7 @@ func (m *sidebarCmp) filesBlock() string {
		core.Section("Modified Files", m.getMaxWidth()),
	)

	files := make([]SessionFile, 0)
	m.files.Range(func(key, value any) bool {
		file := value.(SessionFile)
		files = append(files, file)
		return true // continue iterating
	})
	files := slices.Collect(m.files.Seq())
	if len(files) == 0 {
		return lipgloss.JoinVertical(
			lipgloss.Left,