perf: concurrency improvements

This commit is contained in:
Raphael Amorim
2025-06-17 17:33:22 +02:00
parent b78bf052b4
commit 2937e194c7
5 changed files with 329 additions and 122 deletions

View File

@@ -13,8 +13,7 @@ func TestShellPerformanceComparison(t *testing.T) {
// Test quick command
start := time.Now()
stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'")
exitCode := ExitCode(err)
stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0)
duration := time.Since(start)
require.NoError(t, err)
@@ -33,8 +32,7 @@ func BenchmarkShellPolling(b *testing.B) {
for b.Loop() {
// Use a short sleep to measure polling overhead
_, _, err := shell.Exec(b.Context(), "sleep 0.02")
exitCode := ExitCode(err)
_, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}

View File

@@ -1,87 +1,323 @@
package shell
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"github.com/charmbracelet/crush/internal/logging"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/interp"
"mvdan.cc/sh/v3/syntax"
"syscall"
"time"
)
type PersistentShell struct {
env []string
cwd string
mu sync.Mutex
cmd *exec.Cmd
stdin *os.File
isAlive bool
cwd string
mu sync.Mutex
commandQueue chan *commandExecution
}
type commandExecution struct {
command string
timeout time.Duration
resultChan chan commandResult
ctx context.Context
}
type commandResult struct {
stdout string
stderr string
exitCode int
interrupted bool
err error
}
var (
once sync.Once
shellInstance *PersistentShell
shellInstance *PersistentShell
shellInstanceOnce sync.Once
)
func GetPersistentShell(cwd string) *PersistentShell {
once.Do(func() {
shellInstance = newPersistentShell(cwd)
func GetPersistentShell(workingDir string) *PersistentShell {
shellInstanceOnce.Do(func() {
shellInstance = newPersistentShell(workingDir)
})
if shellInstance == nil {
shellInstance = newPersistentShell(workingDir)
} else if !shellInstance.isAlive {
shellInstance = newPersistentShell(shellInstance.cwd)
}
return shellInstance
}
func newPersistentShell(cwd string) *PersistentShell {
return &PersistentShell{
cwd: cwd,
env: os.Environ(),
// Default to environment variable
shellPath := os.Getenv("SHELL")
if shellPath == "" {
shellPath = "/bin/bash"
}
// Default shell args
shellArgs := []string{"-l"}
cmd := exec.Command(shellPath, shellArgs...)
cmd.Dir = cwd
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return nil
}
cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
err = cmd.Start()
if err != nil {
return nil
}
shell := &PersistentShell{
cmd: cmd,
stdin: stdinPipe.(*os.File),
isAlive: true,
cwd: cwd,
commandQueue: make(chan *commandExecution, 10),
}
go func() {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
shell.isAlive = false
close(shell.commandQueue)
}
}()
shell.processCommands()
}()
go func() {
err := cmd.Wait()
if err != nil {
// Log the error if needed
}
shell.isAlive = false
close(shell.commandQueue)
}()
return shell
}
func (s *PersistentShell) processCommands() {
for cmd := range s.commandQueue {
result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
cmd.resultChan <- result
}
}
func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
s.mu.Lock()
defer s.mu.Unlock()
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
if err != nil {
return "", "", fmt.Errorf("could not parse command: %w", err)
if !s.isAlive {
return commandResult{
stderr: "Shell is not alive",
exitCode: 1,
err: errors.New("shell is not alive"),
}
}
var stdout, stderr bytes.Buffer
runner, err := interp.New(
interp.StdIO(nil, &stdout, &stderr),
interp.Interactive(false),
interp.Env(expand.ListEnviron(s.env...)),
interp.Dir(s.cwd),
tempDir := os.TempDir()
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
defer func() {
os.Remove(stdoutFile)
os.Remove(stderrFile)
os.Remove(statusFile)
os.Remove(cwdFile)
}()
fullCommand := fmt.Sprintf(`
eval %s < /dev/null > %s 2> %s
EXEC_EXIT_CODE=$?
pwd > %s
echo $EXEC_EXIT_CODE > %s
`,
shellQuote(command),
shellQuote(stdoutFile),
shellQuote(stderrFile),
shellQuote(cwdFile),
shellQuote(statusFile),
)
_, err := s.stdin.Write([]byte(fullCommand + "\n"))
if err != nil {
return "", "", fmt.Errorf("could not run command: %w", err)
return commandResult{
stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
exitCode: 1,
err: err,
}
}
err = runner.Run(ctx, line)
s.cwd = runner.Dir
s.env = []string{}
for name, vr := range runner.Vars {
s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
interrupted := false
startTime := time.Now()
done := make(chan bool)
go func() {
// Use exponential backoff polling
pollInterval := 1 * time.Millisecond
maxPollInterval := 100 * time.Millisecond
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.killChildren()
interrupted = true
done <- true
return
case <-ticker.C:
if fileExists(statusFile) && fileSize(statusFile) > 0 {
done <- true
return
}
if timeout > 0 {
elapsed := time.Since(startTime)
if elapsed > timeout {
s.killChildren()
interrupted = true
done <- true
return
}
}
// Exponential backoff to reduce CPU usage for longer-running commands
if pollInterval < maxPollInterval {
pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
ticker.Reset(pollInterval)
}
}
}
}()
<-done
stdout := readFileOrEmpty(stdoutFile)
stderr := readFileOrEmpty(stderrFile)
exitCodeStr := readFileOrEmpty(statusFile)
newCwd := readFileOrEmpty(cwdFile)
exitCode := 0
if exitCodeStr != "" {
fmt.Sscanf(exitCodeStr, "%d", &exitCode)
} else if interrupted {
exitCode = 143
stderr += "\nCommand execution timed out or was interrupted"
}
if newCwd != "" {
s.cwd = strings.TrimSpace(newCwd)
}
return commandResult{
stdout: stdout,
stderr: stderr,
exitCode: exitCode,
interrupted: interrupted,
}
logging.InfoPersist("Command finished", "command", command, "err", err)
return stdout.String(), stderr.String(), err
}
func IsInterrupt(err error) bool {
return errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
func (s *PersistentShell) killChildren() {
if s.cmd == nil || s.cmd.Process == nil {
return
}
pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
output, err := pgrepCmd.Output()
if err != nil {
return
}
for pidStr := range strings.SplitSeq(string(output), "\n") {
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
var pid int
fmt.Sscanf(pidStr, "%d", &pid)
if pid > 0 {
proc, err := os.FindProcess(pid)
if err == nil {
proc.Signal(syscall.SIGTERM)
}
}
}
}
}
func ExitCode(err error) int {
if err == nil {
func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
if !s.isAlive {
return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
}
timeout := time.Duration(timeoutMs) * time.Millisecond
resultChan := make(chan commandResult)
s.commandQueue <- &commandExecution{
command: command,
timeout: timeout,
resultChan: resultChan,
ctx: ctx,
}
result := <-resultChan
return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
}
func (s *PersistentShell) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.isAlive {
return
}
s.stdin.Write([]byte("exit\n"))
s.cmd.Process.Kill()
s.isAlive = false
}
func shellQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}
func readFileOrEmpty(path string) string {
content, err := os.ReadFile(path)
if err != nil {
return ""
}
return string(content)
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func fileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
return 0
}
status, ok := interp.IsExitStatus(err)
if ok {
return int(status)
}
return 1
}
return info.Size()
}

View File

@@ -2,81 +2,28 @@ package shell
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Benchmark to measure CPU efficiency
func BenchmarkShellQuickCommands(b *testing.B) {
shell := newPersistentShell(b.TempDir())
tmpDir, err := os.MkdirTemp("", "shell-bench")
require.NoError(b, err)
defer os.RemoveAll(tmpDir)
shell := GetPersistentShell(tmpDir)
defer shell.Close()
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
_, _, err := shell.Exec(context.Background(), "echo test")
exitCode := ExitCode(err)
_, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}
}
}
func TestTestTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
t.Cleanup(cancel)
shell := newPersistentShell(t.TempDir())
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
}
if !IsInterrupt(err) {
t.Fatalf("Expected command to be interrupted, but it was not")
}
if err == nil {
t.Fatalf("Expected an error due to timeout, but got none")
}
}
func TestTestCancel(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel() // immediately cancel the context
shell := newPersistentShell(t.TempDir())
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
}
if !IsInterrupt(err) {
t.Fatalf("Expected command to be interrupted, but it was not")
}
if err == nil {
t.Fatalf("Expected an error due to cancel, but got none")
}
}
func TestRunCommandError(t *testing.T) {
shell := newPersistentShell(t.TempDir())
_, _, err := shell.Exec(t.Context(), "nopenopenope")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
}
if IsInterrupt(err) {
t.Fatalf("Expected command to not be interrupted, but it was")
}
if err == nil {
t.Fatalf("Expected an error, got nil")
}
}
func TestRunContinuity(t *testing.T) {
shell := newPersistentShell(t.TempDir())
shell.Exec(t.Context(), "export FOO=bar")
dst := t.TempDir()
shell.Exec(t.Context(), "cd "+dst)
out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
expect := "bar\n" + dst + "\n"
if out != expect {
t.Fatalf("Expected output %q, got %q", expect, out)
}
}
}

View File

@@ -1,10 +1,12 @@
package permission
import (
"context"
"errors"
"path/filepath"
"slices"
"sync"
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -44,9 +46,11 @@ type Service interface {
type permissionService struct {
*pubsub.Broker[PermissionRequest]
sessionPermissions []PermissionRequest
pendingRequests sync.Map
autoApproveSessions []string
sessionPermissions []PermissionRequest
sessionPermissionsMu sync.RWMutex
pendingRequests sync.Map
autoApproveSessions []string
autoApproveSessionsMu sync.RWMutex
}
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
@@ -54,7 +58,10 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
if ok {
respCh.(chan bool) <- true
}
s.sessionPermissionsMu.Lock()
s.sessionPermissions = append(s.sessionPermissions, permission)
s.sessionPermissionsMu.Unlock()
}
func (s *permissionService) Grant(permission PermissionRequest) {
@@ -72,9 +79,14 @@ func (s *permissionService) Deny(permission PermissionRequest) {
}
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
if slices.Contains(s.autoApproveSessions, opts.SessionID) {
s.autoApproveSessionsMu.RLock()
autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
s.autoApproveSessionsMu.RUnlock()
if autoApprove {
return true
}
dir := filepath.Dir(opts.Path)
if dir == "." {
dir = config.WorkingDirectory()
@@ -89,11 +101,14 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
Params: opts.Params,
}
s.sessionPermissionsMu.RLock()
for _, p := range s.sessionPermissions {
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
s.sessionPermissionsMu.RUnlock()
return true
}
}
s.sessionPermissionsMu.RUnlock()
respCh := make(chan bool, 1)
@@ -102,13 +117,22 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
s.Publish(pubsub.CreatedEvent, permission)
// Wait for the response with a timeout
resp := <-respCh
return resp
// Wait for the response with a timeout to prevent indefinite blocking
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
select {
case resp := <-respCh:
return resp
case <-ctx.Done():
return false // Timeout - deny by default
}
}
func (s *permissionService) AutoApproveSession(sessionID string) {
s.autoApproveSessionsMu.Lock()
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
s.autoApproveSessionsMu.Unlock()
}
func NewPermissionService() Service {

View File

@@ -111,6 +111,8 @@ func (b *Broker[T]) Publish(t EventType, payload T) {
select {
case sub <- event:
default:
// Channel is full, subscriber is slow - skip this event
// This prevents blocking the publisher
}
}
}