Merge pull request #87 from charmbracelet/crush-shell

Move shell to its own package
This commit is contained in:
Kujtim Hoxha
2025-06-30 18:28:09 +02:00
committed by GitHub
6 changed files with 213 additions and 68 deletions

View File

@@ -9,8 +9,8 @@ import (
"time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/llm/tools/shell"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/shell"
)
type BashParams struct {

View File

@@ -9,7 +9,7 @@ import (
)
func TestShellPerformanceComparison(t *testing.T) {
shell := newPersistentShell(t.TempDir())
shell := NewShell(&Options{WorkingDir: t.TempDir()})
// Test quick command
start := time.Now()
@@ -27,7 +27,7 @@ func TestShellPerformanceComparison(t *testing.T) {
// Benchmark CPU usage during polling
func BenchmarkShellPolling(b *testing.B) {
shell := newPersistentShell(b.TempDir())
shell := NewShell(&Options{WorkingDir: b.TempDir()})
b.ReportAllocs()

30
internal/shell/doc.go Normal file
View File

@@ -0,0 +1,30 @@
package shell
// Example usage of the shell package:
//
// 1. For one-off commands:
//
// shell := shell.NewShell(nil)
// stdout, stderr, err := shell.Exec(context.Background(), "echo hello")
//
// 2. For maintaining state across commands:
//
// shell := shell.NewShell(&shell.Options{
// WorkingDir: "/tmp",
// Logger: myLogger,
// })
// shell.Exec(ctx, "export FOO=bar")
// shell.Exec(ctx, "echo $FOO") // Will print "bar"
//
// 3. For the singleton persistent shell (used by tools):
//
// shell := shell.GetPersistentShell("/path/to/cwd")
// stdout, stderr, err := shell.Exec(ctx, "ls -la")
//
// 4. Managing environment and working directory:
//
// shell := shell.NewShell(nil)
// shell.SetEnv("MY_VAR", "value")
// shell.SetWorkingDir("/tmp")
// cwd := shell.GetWorkingDir()
// env := shell.GetEnv()

View File

@@ -0,0 +1,38 @@
package shell
import (
"sync"
"github.com/charmbracelet/crush/internal/logging"
)
// PersistentShell is a singleton shell instance that maintains state across the application
type PersistentShell struct {
*Shell
}
var (
once sync.Once
shellInstance *PersistentShell
)
// GetPersistentShell returns the singleton persistent shell instance
// This maintains backward compatibility with the existing API
func GetPersistentShell(cwd string) *PersistentShell {
once.Do(func() {
shellInstance = &PersistentShell{
Shell: NewShell(&Options{
WorkingDir: cwd,
Logger: &loggingAdapter{},
}),
}
})
return shellInstance
}
// loggingAdapter adapts the internal logging package to the Logger interface
type loggingAdapter struct{}
func (l *loggingAdapter) InfoPersist(msg string, keysAndValues ...interface{}) {
logging.InfoPersist(msg, keysAndValues...)
}

View File

@@ -1,11 +1,12 @@
// Package shell provides cross-platform shell execution capabilities.
//
// This package offers two main types:
// - Shell: A general-purpose shell executor for one-off or managed commands
// - PersistentShell: A singleton shell that maintains state across the application
//
// WINDOWS COMPATIBILITY:
// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
// - Automatic detection: Chooses the best shell based on command and platform
// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
package shell
import (
@@ -19,7 +20,6 @@ import (
"strings"
"sync"
"github.com/charmbracelet/crush/internal/logging"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/interp"
"mvdan.cc/sh/v3/syntax"
@@ -34,16 +34,123 @@ const (
ShellTypePowerShell
)
type PersistentShell struct {
env []string
cwd string
mu sync.Mutex
// Logger interface for optional logging
type Logger interface {
InfoPersist(msg string, keysAndValues ...interface{})
}
var (
once sync.Once
shellInstance *PersistentShell
)
// noopLogger is a logger that does nothing
type noopLogger struct{}
func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
// Shell provides cross-platform shell execution with optional state persistence
type Shell struct {
env []string
cwd string
mu sync.Mutex
logger Logger
}
// Options for creating a new shell
type Options struct {
WorkingDir string
Env []string
Logger Logger
}
// NewShell creates a new shell instance with the given options
func NewShell(opts *Options) *Shell {
if opts == nil {
opts = &Options{}
}
cwd := opts.WorkingDir
if cwd == "" {
cwd, _ = os.Getwd()
}
env := opts.Env
if env == nil {
env = os.Environ()
}
logger := opts.Logger
if logger == nil {
logger = noopLogger{}
}
return &Shell{
cwd: cwd,
env: env,
logger: logger,
}
}
// Exec executes a command in the shell
func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Determine which shell to use based on platform and command
shellType := s.determineShellType(command)
switch shellType {
case ShellTypeCmd:
return s.execWindows(ctx, command, "cmd")
case ShellTypePowerShell:
return s.execWindows(ctx, command, "powershell")
default:
return s.execPOSIX(ctx, command)
}
}
// GetWorkingDir returns the current working directory
func (s *Shell) GetWorkingDir() string {
s.mu.Lock()
defer s.mu.Unlock()
return s.cwd
}
// SetWorkingDir sets the working directory
func (s *Shell) SetWorkingDir(dir string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Verify the directory exists
if _, err := os.Stat(dir); err != nil {
return fmt.Errorf("directory does not exist: %w", err)
}
s.cwd = dir
return nil
}
// GetEnv returns a copy of the environment variables
func (s *Shell) GetEnv() []string {
s.mu.Lock()
defer s.mu.Unlock()
env := make([]string, len(s.env))
copy(env, s.env)
return env
}
// SetEnv sets an environment variable
func (s *Shell) SetEnv(key, value string) {
s.mu.Lock()
defer s.mu.Unlock()
// Update or add the environment variable
keyPrefix := key + "="
for i, env := range s.env {
if strings.HasPrefix(env, keyPrefix) {
s.env[i] = keyPrefix + value
return
}
}
s.env = append(s.env, keyPrefix+value)
}
// Windows-specific commands that should use native shell
var windowsNativeCommands = map[string]bool{
@@ -66,39 +173,8 @@ var windowsNativeCommands = map[string]bool{
"wmic": true,
}
func GetPersistentShell(cwd string) *PersistentShell {
once.Do(func() {
shellInstance = newPersistentShell(cwd)
})
return shellInstance
}
func newPersistentShell(cwd string) *PersistentShell {
return &PersistentShell{
cwd: cwd,
env: os.Environ(),
}
}
func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Determine which shell to use based on platform and command
shellType := s.determineShellType(command)
switch shellType {
case ShellTypeCmd:
return s.execWindows(ctx, command, "cmd")
case ShellTypePowerShell:
return s.execWindows(ctx, command, "powershell")
default:
return s.execPOSIX(ctx, command)
}
}
// determineShellType decides which shell to use based on platform and command
func (s *PersistentShell) determineShellType(command string) ShellType {
func (s *Shell) determineShellType(command string) ShellType {
if runtime.GOOS != "windows" {
return ShellTypePOSIX
}
@@ -128,7 +204,7 @@ func (s *PersistentShell) determineShellType(command string) ShellType {
}
// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
var cmd *exec.Cmd
// Handle directory changes specially to maintain persistent shell behavior
@@ -160,12 +236,12 @@ func (s *PersistentShell) execWindows(ctx context.Context, command string, shell
err := cmd.Run()
logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
return stdout.String(), stderr.String(), err
}
// handleWindowsCD handles directory changes for Windows shells
func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
func (s *Shell) handleWindowsCD(command string) (string, string, error) {
// Extract the target directory from the cd command
parts := strings.Fields(command)
if len(parts) < 2 {
@@ -203,7 +279,7 @@ func (s *PersistentShell) handleWindowsCD(command string) (string, string, error
}
// execPOSIX executes commands using POSIX shell emulation (cross-platform)
func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
if err != nil {
return "", "", fmt.Errorf("could not parse command: %w", err)
@@ -226,15 +302,17 @@ func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string
for name, vr := range runner.Vars {
s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
}
logging.InfoPersist("POSIX command finished", "command", command, "err", err)
s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
return stdout.String(), stderr.String(), err
}
// IsInterrupt checks if an error is due to interruption
func IsInterrupt(err error) bool {
return errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
}
// ExitCode extracts the exit code from an error
func ExitCode(err error) int {
if err == nil {
return 0

View File

@@ -10,7 +10,7 @@ import (
// Benchmark to measure CPU efficiency
func BenchmarkShellQuickCommands(b *testing.B) {
shell := newPersistentShell(b.TempDir())
shell := NewShell(&Options{WorkingDir: b.TempDir()})
b.ReportAllocs()
@@ -27,7 +27,7 @@ func TestTestTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
t.Cleanup(cancel)
shell := newPersistentShell(t.TempDir())
shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -44,7 +44,7 @@ func TestTestCancel(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel() // immediately cancel the context
shell := newPersistentShell(t.TempDir())
shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -58,7 +58,7 @@ func TestTestCancel(t *testing.T) {
}
func TestRunCommandError(t *testing.T) {
shell := newPersistentShell(t.TempDir())
shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(t.Context(), "nopenopenope")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -72,7 +72,7 @@ func TestRunCommandError(t *testing.T) {
}
func TestRunContinuity(t *testing.T) {
shell := newPersistentShell(t.TempDir())
shell := NewShell(&Options{WorkingDir: t.TempDir()})
shell.Exec(t.Context(), "export FOO=bar")
dst := t.TempDir()
shell.Exec(t.Context(), "cd "+dst)
@@ -141,10 +141,9 @@ func TestWindowsCDHandling(t *testing.T) {
t.Skip("Windows CD handling test only runs on Windows")
}
shell := &PersistentShell{
cwd: "C:\\Users",
env: []string{},
}
shell := NewShell(&Options{
WorkingDir: "C:\\Users",
})
tests := []struct {
command string
@@ -159,7 +158,7 @@ func TestWindowsCDHandling(t *testing.T) {
for _, test := range tests {
t.Run(test.command, func(t *testing.T) {
originalCwd := shell.cwd
originalCwd := shell.GetWorkingDir()
stdout, stderr, err := shell.handleWindowsCD(test.command)
if test.shouldError {
@@ -170,13 +169,13 @@ func TestWindowsCDHandling(t *testing.T) {
if err != nil {
t.Errorf("Command %q failed: %v", test.command, err)
}
if shell.cwd != test.expectedCwd {
t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd)
if shell.GetWorkingDir() != test.expectedCwd {
t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir())
}
}
// Reset for next test
shell.cwd = originalCwd
shell.SetWorkingDir(originalCwd)
_ = stdout
_ = stderr
})
@@ -184,7 +183,7 @@ func TestWindowsCDHandling(t *testing.T) {
}
func TestCrossPlatformExecution(t *testing.T) {
shell := newPersistentShell(".")
shell := NewShell(&Options{WorkingDir: "."})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -209,7 +208,7 @@ func TestWindowsNativeCommands(t *testing.T) {
t.Skip("Windows native command test only runs on Windows")
}
shell := newPersistentShell(".")
shell := NewShell(&Options{WorkingDir: "."})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()