mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Merge pull request #87 from charmbracelet/crush-shell
Move shell to its own package
This commit is contained in:
@@ -9,8 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools/shell"
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
"github.com/charmbracelet/crush/internal/shell"
|
||||
)
|
||||
|
||||
type BashParams struct {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestShellPerformanceComparison(t *testing.T) {
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: t.TempDir()})
|
||||
|
||||
// Test quick command
|
||||
start := time.Now()
|
||||
@@ -27,7 +27,7 @@ func TestShellPerformanceComparison(t *testing.T) {
|
||||
|
||||
// Benchmark CPU usage during polling
|
||||
func BenchmarkShellPolling(b *testing.B) {
|
||||
shell := newPersistentShell(b.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: b.TempDir()})
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
30
internal/shell/doc.go
Normal file
30
internal/shell/doc.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package shell
|
||||
|
||||
// Example usage of the shell package:
|
||||
//
|
||||
// 1. For one-off commands:
|
||||
//
|
||||
// shell := shell.NewShell(nil)
|
||||
// stdout, stderr, err := shell.Exec(context.Background(), "echo hello")
|
||||
//
|
||||
// 2. For maintaining state across commands:
|
||||
//
|
||||
// shell := shell.NewShell(&shell.Options{
|
||||
// WorkingDir: "/tmp",
|
||||
// Logger: myLogger,
|
||||
// })
|
||||
// shell.Exec(ctx, "export FOO=bar")
|
||||
// shell.Exec(ctx, "echo $FOO") // Will print "bar"
|
||||
//
|
||||
// 3. For the singleton persistent shell (used by tools):
|
||||
//
|
||||
// shell := shell.GetPersistentShell("/path/to/cwd")
|
||||
// stdout, stderr, err := shell.Exec(ctx, "ls -la")
|
||||
//
|
||||
// 4. Managing environment and working directory:
|
||||
//
|
||||
// shell := shell.NewShell(nil)
|
||||
// shell.SetEnv("MY_VAR", "value")
|
||||
// shell.SetWorkingDir("/tmp")
|
||||
// cwd := shell.GetWorkingDir()
|
||||
// env := shell.GetEnv()
|
||||
38
internal/shell/persistent.go
Normal file
38
internal/shell/persistent.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
)
|
||||
|
||||
// PersistentShell is a singleton shell instance that maintains state across the application
|
||||
type PersistentShell struct {
|
||||
*Shell
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
shellInstance *PersistentShell
|
||||
)
|
||||
|
||||
// GetPersistentShell returns the singleton persistent shell instance
|
||||
// This maintains backward compatibility with the existing API
|
||||
func GetPersistentShell(cwd string) *PersistentShell {
|
||||
once.Do(func() {
|
||||
shellInstance = &PersistentShell{
|
||||
Shell: NewShell(&Options{
|
||||
WorkingDir: cwd,
|
||||
Logger: &loggingAdapter{},
|
||||
}),
|
||||
}
|
||||
})
|
||||
return shellInstance
|
||||
}
|
||||
|
||||
// loggingAdapter adapts the internal logging package to the Logger interface
|
||||
type loggingAdapter struct{}
|
||||
|
||||
func (l *loggingAdapter) InfoPersist(msg string, keysAndValues ...interface{}) {
|
||||
logging.InfoPersist(msg, keysAndValues...)
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
// Package shell provides cross-platform shell execution capabilities.
|
||||
//
|
||||
// This package offers two main types:
|
||||
// - Shell: A general-purpose shell executor for one-off or managed commands
|
||||
// - PersistentShell: A singleton shell that maintains state across the application
|
||||
//
|
||||
// WINDOWS COMPATIBILITY:
|
||||
// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
|
||||
// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
|
||||
// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
|
||||
// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
|
||||
// - Automatic detection: Chooses the best shell based on command and platform
|
||||
// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
|
||||
package shell
|
||||
|
||||
import (
|
||||
@@ -19,7 +20,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/logging"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/interp"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
@@ -34,16 +34,123 @@ const (
|
||||
ShellTypePowerShell
|
||||
)
|
||||
|
||||
type PersistentShell struct {
|
||||
env []string
|
||||
cwd string
|
||||
mu sync.Mutex
|
||||
// Logger interface for optional logging
|
||||
type Logger interface {
|
||||
InfoPersist(msg string, keysAndValues ...interface{})
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
shellInstance *PersistentShell
|
||||
)
|
||||
// noopLogger is a logger that does nothing
|
||||
type noopLogger struct{}
|
||||
|
||||
func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
|
||||
|
||||
// Shell provides cross-platform shell execution with optional state persistence
|
||||
type Shell struct {
|
||||
env []string
|
||||
cwd string
|
||||
mu sync.Mutex
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Options for creating a new shell
|
||||
type Options struct {
|
||||
WorkingDir string
|
||||
Env []string
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// NewShell creates a new shell instance with the given options
|
||||
func NewShell(opts *Options) *Shell {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
cwd := opts.WorkingDir
|
||||
if cwd == "" {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
env := opts.Env
|
||||
if env == nil {
|
||||
env = os.Environ()
|
||||
}
|
||||
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = noopLogger{}
|
||||
}
|
||||
|
||||
return &Shell{
|
||||
cwd: cwd,
|
||||
env: env,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Exec executes a command in the shell
|
||||
func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Determine which shell to use based on platform and command
|
||||
shellType := s.determineShellType(command)
|
||||
|
||||
switch shellType {
|
||||
case ShellTypeCmd:
|
||||
return s.execWindows(ctx, command, "cmd")
|
||||
case ShellTypePowerShell:
|
||||
return s.execWindows(ctx, command, "powershell")
|
||||
default:
|
||||
return s.execPOSIX(ctx, command)
|
||||
}
|
||||
}
|
||||
|
||||
// GetWorkingDir returns the current working directory
|
||||
func (s *Shell) GetWorkingDir() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.cwd
|
||||
}
|
||||
|
||||
// SetWorkingDir sets the working directory
|
||||
func (s *Shell) SetWorkingDir(dir string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Verify the directory exists
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
return fmt.Errorf("directory does not exist: %w", err)
|
||||
}
|
||||
|
||||
s.cwd = dir
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnv returns a copy of the environment variables
|
||||
func (s *Shell) GetEnv() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
env := make([]string, len(s.env))
|
||||
copy(env, s.env)
|
||||
return env
|
||||
}
|
||||
|
||||
// SetEnv sets an environment variable
|
||||
func (s *Shell) SetEnv(key, value string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Update or add the environment variable
|
||||
keyPrefix := key + "="
|
||||
for i, env := range s.env {
|
||||
if strings.HasPrefix(env, keyPrefix) {
|
||||
s.env[i] = keyPrefix + value
|
||||
return
|
||||
}
|
||||
}
|
||||
s.env = append(s.env, keyPrefix+value)
|
||||
}
|
||||
|
||||
// Windows-specific commands that should use native shell
|
||||
var windowsNativeCommands = map[string]bool{
|
||||
@@ -66,39 +173,8 @@ var windowsNativeCommands = map[string]bool{
|
||||
"wmic": true,
|
||||
}
|
||||
|
||||
func GetPersistentShell(cwd string) *PersistentShell {
|
||||
once.Do(func() {
|
||||
shellInstance = newPersistentShell(cwd)
|
||||
})
|
||||
return shellInstance
|
||||
}
|
||||
|
||||
func newPersistentShell(cwd string) *PersistentShell {
|
||||
return &PersistentShell{
|
||||
cwd: cwd,
|
||||
env: os.Environ(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Determine which shell to use based on platform and command
|
||||
shellType := s.determineShellType(command)
|
||||
|
||||
switch shellType {
|
||||
case ShellTypeCmd:
|
||||
return s.execWindows(ctx, command, "cmd")
|
||||
case ShellTypePowerShell:
|
||||
return s.execWindows(ctx, command, "powershell")
|
||||
default:
|
||||
return s.execPOSIX(ctx, command)
|
||||
}
|
||||
}
|
||||
|
||||
// determineShellType decides which shell to use based on platform and command
|
||||
func (s *PersistentShell) determineShellType(command string) ShellType {
|
||||
func (s *Shell) determineShellType(command string) ShellType {
|
||||
if runtime.GOOS != "windows" {
|
||||
return ShellTypePOSIX
|
||||
}
|
||||
@@ -128,7 +204,7 @@ func (s *PersistentShell) determineShellType(command string) ShellType {
|
||||
}
|
||||
|
||||
// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
|
||||
func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
|
||||
func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
// Handle directory changes specially to maintain persistent shell behavior
|
||||
@@ -160,12 +236,12 @@ func (s *PersistentShell) execWindows(ctx context.Context, command string, shell
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
|
||||
s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
|
||||
return stdout.String(), stderr.String(), err
|
||||
}
|
||||
|
||||
// handleWindowsCD handles directory changes for Windows shells
|
||||
func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
|
||||
func (s *Shell) handleWindowsCD(command string) (string, string, error) {
|
||||
// Extract the target directory from the cd command
|
||||
parts := strings.Fields(command)
|
||||
if len(parts) < 2 {
|
||||
@@ -203,7 +279,7 @@ func (s *PersistentShell) handleWindowsCD(command string) (string, string, error
|
||||
}
|
||||
|
||||
// execPOSIX executes commands using POSIX shell emulation (cross-platform)
|
||||
func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
|
||||
func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
|
||||
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("could not parse command: %w", err)
|
||||
@@ -226,15 +302,17 @@ func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string
|
||||
for name, vr := range runner.Vars {
|
||||
s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
|
||||
}
|
||||
logging.InfoPersist("POSIX command finished", "command", command, "err", err)
|
||||
s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
|
||||
return stdout.String(), stderr.String(), err
|
||||
}
|
||||
|
||||
// IsInterrupt checks if an error is due to interruption
|
||||
func IsInterrupt(err error) bool {
|
||||
return errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// ExitCode extracts the exit code from an error
|
||||
func ExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
// Benchmark to measure CPU efficiency
|
||||
func BenchmarkShellQuickCommands(b *testing.B) {
|
||||
shell := newPersistentShell(b.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: b.TempDir()})
|
||||
|
||||
b.ReportAllocs()
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestTestTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: t.TempDir()})
|
||||
_, _, err := shell.Exec(ctx, "sleep 10")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
@@ -44,7 +44,7 @@ func TestTestCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel() // immediately cancel the context
|
||||
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: t.TempDir()})
|
||||
_, _, err := shell.Exec(ctx, "sleep 10")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
@@ -58,7 +58,7 @@ func TestTestCancel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunCommandError(t *testing.T) {
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: t.TempDir()})
|
||||
_, _, err := shell.Exec(t.Context(), "nopenopenope")
|
||||
if status := ExitCode(err); status == 0 {
|
||||
t.Fatalf("Expected non-zero exit status, got %d", status)
|
||||
@@ -72,7 +72,7 @@ func TestRunCommandError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunContinuity(t *testing.T) {
|
||||
shell := newPersistentShell(t.TempDir())
|
||||
shell := NewShell(&Options{WorkingDir: t.TempDir()})
|
||||
shell.Exec(t.Context(), "export FOO=bar")
|
||||
dst := t.TempDir()
|
||||
shell.Exec(t.Context(), "cd "+dst)
|
||||
@@ -141,10 +141,9 @@ func TestWindowsCDHandling(t *testing.T) {
|
||||
t.Skip("Windows CD handling test only runs on Windows")
|
||||
}
|
||||
|
||||
shell := &PersistentShell{
|
||||
cwd: "C:\\Users",
|
||||
env: []string{},
|
||||
}
|
||||
shell := NewShell(&Options{
|
||||
WorkingDir: "C:\\Users",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
command string
|
||||
@@ -159,7 +158,7 @@ func TestWindowsCDHandling(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.command, func(t *testing.T) {
|
||||
originalCwd := shell.cwd
|
||||
originalCwd := shell.GetWorkingDir()
|
||||
stdout, stderr, err := shell.handleWindowsCD(test.command)
|
||||
|
||||
if test.shouldError {
|
||||
@@ -170,13 +169,13 @@ func TestWindowsCDHandling(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("Command %q failed: %v", test.command, err)
|
||||
}
|
||||
if shell.cwd != test.expectedCwd {
|
||||
t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd)
|
||||
if shell.GetWorkingDir() != test.expectedCwd {
|
||||
t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir())
|
||||
}
|
||||
}
|
||||
|
||||
// Reset for next test
|
||||
shell.cwd = originalCwd
|
||||
shell.SetWorkingDir(originalCwd)
|
||||
_ = stdout
|
||||
_ = stderr
|
||||
})
|
||||
@@ -184,7 +183,7 @@ func TestWindowsCDHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCrossPlatformExecution(t *testing.T) {
|
||||
shell := newPersistentShell(".")
|
||||
shell := NewShell(&Options{WorkingDir: "."})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -209,7 +208,7 @@ func TestWindowsNativeCommands(t *testing.T) {
|
||||
t.Skip("Windows native command test only runs on Windows")
|
||||
}
|
||||
|
||||
shell := newPersistentShell(".")
|
||||
shell := NewShell(&Options{WorkingDir: "."})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
Reference in New Issue
Block a user