mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
perf: concurrency improvements
This commit is contained in:
@@ -13,8 +13,7 @@ func TestShellPerformanceComparison(t *testing.T) {
|
||||
|
||||
// Test quick command
|
||||
start := time.Now()
|
||||
stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'")
|
||||
exitCode := ExitCode(err)
|
||||
stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0)
|
||||
duration := time.Since(start)
|
||||
|
||||
require.NoError(t, err)
|
||||
@@ -33,8 +32,7 @@ func BenchmarkShellPolling(b *testing.B) {
|
||||
|
||||
for b.Loop() {
|
||||
// Use a short sleep to measure polling overhead
|
||||
_, _, err := shell.Exec(b.Context(), "sleep 0.02")
|
||||
exitCode := ExitCode(err)
|
||||
_, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0)
|
||||
if err != nil || exitCode != 0 {
|
||||
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
|
||||
}
|
||||
|
||||
@@ -1,87 +1,323 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/interp"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PersistentShell struct {
|
||||
env []string
|
||||
cwd string
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
stdin *os.File
|
||||
isAlive bool
|
||||
cwd string
|
||||
mu sync.Mutex
|
||||
commandQueue chan *commandExecution
|
||||
}
|
||||
|
||||
type commandExecution struct {
|
||||
command string
|
||||
timeout time.Duration
|
||||
resultChan chan commandResult
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
stdout string
|
||||
stderr string
|
||||
exitCode int
|
||||
interrupted bool
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
shellInstance *PersistentShell
|
||||
shellInstance *PersistentShell
|
||||
shellInstanceOnce sync.Once
|
||||
)
|
||||
|
||||
func GetPersistentShell(cwd string) *PersistentShell {
|
||||
once.Do(func() {
|
||||
shellInstance = newPersistentShell(cwd)
|
||||
func GetPersistentShell(workingDir string) *PersistentShell {
|
||||
shellInstanceOnce.Do(func() {
|
||||
shellInstance = newPersistentShell(workingDir)
|
||||
})
|
||||
|
||||
if shellInstance == nil {
|
||||
shellInstance = newPersistentShell(workingDir)
|
||||
} else if !shellInstance.isAlive {
|
||||
shellInstance = newPersistentShell(shellInstance.cwd)
|
||||
}
|
||||
|
||||
return shellInstance
|
||||
}
|
||||
|
||||
func newPersistentShell(cwd string) *PersistentShell {
|
||||
return &PersistentShell{
|
||||
cwd: cwd,
|
||||
env: os.Environ(),
|
||||
// Default to environment variable
|
||||
shellPath := os.Getenv("SHELL")
|
||||
if shellPath == "" {
|
||||
shellPath = "/bin/bash"
|
||||
}
|
||||
|
||||
// Default shell args
|
||||
shellArgs := []string{"-l"}
|
||||
|
||||
cmd := exec.Command(shellPath, shellArgs...)
|
||||
cmd.Dir = cwd
|
||||
|
||||
stdinPipe, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
shell := &PersistentShell{
|
||||
cmd: cmd,
|
||||
stdin: stdinPipe.(*os.File),
|
||||
isAlive: true,
|
||||
cwd: cwd,
|
||||
commandQueue: make(chan *commandExecution, 10),
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
|
||||
shell.isAlive = false
|
||||
close(shell.commandQueue)
|
||||
}
|
||||
}()
|
||||
shell.processCommands()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
if err != nil {
|
||||
// Log the error if needed
|
||||
}
|
||||
shell.isAlive = false
|
||||
close(shell.commandQueue)
|
||||
}()
|
||||
|
||||
return shell
|
||||
}
|
||||
|
||||
func (s *PersistentShell) processCommands() {
|
||||
for cmd := range s.commandQueue {
|
||||
result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
|
||||
cmd.resultChan <- result
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
|
||||
func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("could not parse command: %w", err)
|
||||
if !s.isAlive {
|
||||
return commandResult{
|
||||
stderr: "Shell is not alive",
|
||||
exitCode: 1,
|
||||
err: errors.New("shell is not alive"),
|
||||
}
|
||||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
runner, err := interp.New(
|
||||
interp.StdIO(nil, &stdout, &stderr),
|
||||
interp.Interactive(false),
|
||||
interp.Env(expand.ListEnviron(s.env...)),
|
||||
interp.Dir(s.cwd),
|
||||
tempDir := os.TempDir()
|
||||
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
|
||||
stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
|
||||
statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
|
||||
cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
|
||||
|
||||
defer func() {
|
||||
os.Remove(stdoutFile)
|
||||
os.Remove(stderrFile)
|
||||
os.Remove(statusFile)
|
||||
os.Remove(cwdFile)
|
||||
}()
|
||||
|
||||
fullCommand := fmt.Sprintf(`
|
||||
eval %s < /dev/null > %s 2> %s
|
||||
EXEC_EXIT_CODE=$?
|
||||
pwd > %s
|
||||
echo $EXEC_EXIT_CODE > %s
|
||||
`,
|
||||
shellQuote(command),
|
||||
shellQuote(stdoutFile),
|
||||
shellQuote(stderrFile),
|
||||
shellQuote(cwdFile),
|
||||
shellQuote(statusFile),
|
||||
)
|
||||
|
||||
_, err := s.stdin.Write([]byte(fullCommand + "\n"))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("could not run command: %w", err)
|
||||
return commandResult{
|
||||
stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
|
||||
exitCode: 1,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
err = runner.Run(ctx, line)
|
||||
s.cwd = runner.Dir
|
||||
s.env = []string{}
|
||||
for name, vr := range runner.Vars {
|
||||
s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
|
||||
interrupted := false
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
// Use exponential backoff polling
|
||||
pollInterval := 1 * time.Millisecond
|
||||
maxPollInterval := 100 * time.Millisecond
|
||||
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.killChildren()
|
||||
interrupted = true
|
||||
done <- true
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
if fileExists(statusFile) && fileSize(statusFile) > 0 {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
|
||||
if timeout > 0 {
|
||||
elapsed := time.Since(startTime)
|
||||
if elapsed > timeout {
|
||||
s.killChildren()
|
||||
interrupted = true
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential backoff to reduce CPU usage for longer-running commands
|
||||
if pollInterval < maxPollInterval {
|
||||
pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
|
||||
ticker.Reset(pollInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
|
||||
stdout := readFileOrEmpty(stdoutFile)
|
||||
stderr := readFileOrEmpty(stderrFile)
|
||||
exitCodeStr := readFileOrEmpty(statusFile)
|
||||
newCwd := readFileOrEmpty(cwdFile)
|
||||
|
||||
exitCode := 0
|
||||
if exitCodeStr != "" {
|
||||
fmt.Sscanf(exitCodeStr, "%d", &exitCode)
|
||||
} else if interrupted {
|
||||
exitCode = 143
|
||||
stderr += "\nCommand execution timed out or was interrupted"
|
||||
}
|
||||
|
||||
if newCwd != "" {
|
||||
s.cwd = strings.TrimSpace(newCwd)
|
||||
}
|
||||
|
||||
return commandResult{
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
exitCode: exitCode,
|
||||
interrupted: interrupted,
|
||||
}
|
||||
logging.InfoPersist("Command finished", "command", command, "err", err)
|
||||
return stdout.String(), stderr.String(), err
|
||||
}
|
||||
|
||||
func IsInterrupt(err error) bool {
|
||||
return errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, context.DeadlineExceeded)
|
||||
func (s *PersistentShell) killChildren() {
|
||||
if s.cmd == nil || s.cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
|
||||
output, err := pgrepCmd.Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for pidStr := range strings.SplitSeq(string(output), "\n") {
|
||||
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
|
||||
var pid int
|
||||
fmt.Sscanf(pidStr, "%d", &pid)
|
||||
if pid > 0 {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err == nil {
|
||||
proc.Signal(syscall.SIGTERM)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExitCode(err error) int {
|
||||
if err == nil {
|
||||
func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
|
||||
if !s.isAlive {
|
||||
return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutMs) * time.Millisecond
|
||||
|
||||
resultChan := make(chan commandResult)
|
||||
s.commandQueue <- &commandExecution{
|
||||
command: command,
|
||||
timeout: timeout,
|
||||
resultChan: resultChan,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
result := <-resultChan
|
||||
return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
|
||||
}
|
||||
|
||||
func (s *PersistentShell) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.isAlive {
|
||||
return
|
||||
}
|
||||
|
||||
s.stdin.Write([]byte("exit\n"))
|
||||
|
||||
s.cmd.Process.Kill()
|
||||
s.isAlive = false
|
||||
}
|
||||
|
||||
func shellQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
func readFileOrEmpty(path string) string {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(content)
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func fileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
status, ok := interp.IsExitStatus(err)
|
||||
if ok {
|
||||
return int(status)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
@@ -2,81 +2,28 @@ package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Benchmark to measure CPU efficiency
|
||||
func BenchmarkShellQuickCommands(b *testing.B) {
|
||||
shell := newPersistentShell(b.TempDir())
|
||||
tmpDir, err := os.MkdirTemp("", "shell-bench")
|
||||
require.NoError(b, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
shell := GetPersistentShell(tmpDir)
|
||||
defer shell.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
_, _, err := shell.Exec(context.Background(), "echo test")
|
||||
exitCode := ExitCode(err)
|
||||
_, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
|
||||
if err != nil || exitCode != 0 {
|
||||
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
_, _, err := shell.Exec(ctx, "sleep 10")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
}
|
||||
if !IsInterrupt(err) {
|
||||
t.Fatalf("Expected command to be interrupted, but it was not")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error due to timeout, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel() // immediately cancel the context
|
||||
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
_, _, err := shell.Exec(ctx, "sleep 10")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
}
|
||||
if !IsInterrupt(err) {
|
||||
t.Fatalf("Expected command to be interrupted, but it was not")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error due to cancel, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCommandError(t *testing.T) {
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
_, _, err := shell.Exec(t.Context(), "nopenopenope")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
}
|
||||
if IsInterrupt(err) {
|
||||
t.Fatalf("Expected command to not be interrupted, but it was")
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Expected an error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunContinuity(t *testing.T) {
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell.Exec(t.Context(), "export FOO=bar")
|
||||
dst := t.TempDir()
|
||||
shell.Exec(t.Context(), "cd "+dst)
|
||||
out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
|
||||
expect := "bar\n" + dst + "\n"
|
||||
if out != expect {
|
||||
t.Fatalf("Expected output %q, got %q", expect, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/pubsub"
|
||||
@@ -44,9 +46,11 @@ type Service interface {
|
||||
type permissionService struct {
|
||||
*pubsub.Broker[PermissionRequest]
|
||||
|
||||
sessionPermissions []PermissionRequest
|
||||
pendingRequests sync.Map
|
||||
autoApproveSessions []string
|
||||
sessionPermissions []PermissionRequest
|
||||
sessionPermissionsMu sync.RWMutex
|
||||
pendingRequests sync.Map
|
||||
autoApproveSessions []string
|
||||
autoApproveSessionsMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
|
||||
@@ -54,7 +58,10 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
}
|
||||
|
||||
s.sessionPermissionsMu.Lock()
|
||||
s.sessionPermissions = append(s.sessionPermissions, permission)
|
||||
s.sessionPermissionsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *permissionService) Grant(permission PermissionRequest) {
|
||||
@@ -72,9 +79,14 @@ func (s *permissionService) Deny(permission PermissionRequest) {
|
||||
}
|
||||
|
||||
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
if slices.Contains(s.autoApproveSessions, opts.SessionID) {
|
||||
s.autoApproveSessionsMu.RLock()
|
||||
autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
|
||||
s.autoApproveSessionsMu.RUnlock()
|
||||
|
||||
if autoApprove {
|
||||
return true
|
||||
}
|
||||
|
||||
dir := filepath.Dir(opts.Path)
|
||||
if dir == "." {
|
||||
dir = config.WorkingDirectory()
|
||||
@@ -89,11 +101,14 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
Params: opts.Params,
|
||||
}
|
||||
|
||||
s.sessionPermissionsMu.RLock()
|
||||
for _, p := range s.sessionPermissions {
|
||||
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
|
||||
s.sessionPermissionsMu.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
s.sessionPermissionsMu.RUnlock()
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
|
||||
@@ -102,13 +117,22 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
|
||||
s.Publish(pubsub.CreatedEvent, permission)
|
||||
|
||||
// Wait for the response with a timeout
|
||||
resp := <-respCh
|
||||
return resp
|
||||
// Wait for the response with a timeout to prevent indefinite blocking
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp
|
||||
case <-ctx.Done():
|
||||
return false // Timeout - deny by default
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(sessionID string) {
|
||||
s.autoApproveSessionsMu.Lock()
|
||||
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
|
||||
s.autoApproveSessionsMu.Unlock()
|
||||
}
|
||||
|
||||
func NewPermissionService() Service {
|
||||
|
||||
@@ -111,6 +111,8 @@ func (b *Broker[T]) Publish(t EventType, payload T) {
|
||||
select {
|
||||
case sub <- event:
|
||||
default:
|
||||
// Channel is full, subscriber is slow - skip this event
|
||||
// This prevents blocking the publisher
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user