mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
chore: improve permissions & edit tool
This commit is contained in:
@@ -203,6 +203,7 @@ func (app *App) setupEvents() {
|
||||
setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
|
||||
setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
|
||||
setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
|
||||
setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
|
||||
setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
|
||||
cleanupFunc := func() {
|
||||
cancel()
|
||||
|
||||
@@ -100,6 +100,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
||||
p := b.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
ToolCallID: params.ID,
|
||||
Path: b.workingDir,
|
||||
ToolName: b.Info().Name,
|
||||
Action: "execute",
|
||||
|
||||
@@ -107,7 +107,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN
|
||||
|
||||
# Tool usage policy
|
||||
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
|
||||
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel.
|
||||
- IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them).
|
||||
- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user.
|
||||
|
||||
# Proactiveness
|
||||
@@ -217,7 +217,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN
|
||||
|
||||
# Tool usage policy
|
||||
- When doing file search, prefer to use the Agent tool in order to reduce context usage.
|
||||
- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel.
|
||||
- IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them).
|
||||
- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user.
|
||||
|
||||
VERY IMPORTANT NEVER use emojis in your responses.
|
||||
@@ -281,7 +281,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN
|
||||
|
||||
## Tool Usage
|
||||
- **File Paths:** Always use absolute paths when referring to files with tools like ` + "`view`" + ` or ` + "`write`" + `. Relative paths are not supported. You must provide an absolute path.
|
||||
- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase).
|
||||
- **Parallelism:** IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them).
|
||||
- **Command Execution:** Use the ` + "`bash`" + ` tool for running shell commands, remembering the safety rule to explain modifying commands first.
|
||||
- **Background Processes:** Use background processes (via ` + "`&`" + `) for commands that are unlikely to stop on their own, e.g. ` + "`node server.js &`" + `. If unsure, ask the user.
|
||||
- **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. ` + "`git rebase -i`" + `). Use non-interactive versions of commands (e.g. ` + "`npm init -y`" + ` instead of ` + "`npm init`" + `) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user.
|
||||
|
||||
@@ -373,6 +373,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: b.workingDir,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: BashToolName,
|
||||
Action: "execute",
|
||||
Description: fmt.Sprintf("Execute command: %s", params.Command),
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
)
|
||||
|
||||
type EditParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
OldString string `json:"old_string"`
|
||||
NewString string `json:"new_string"`
|
||||
FilePath string `json:"file_path"`
|
||||
OldString string `json:"old_string"`
|
||||
NewString string `json:"new_string"`
|
||||
ReplaceAll bool `json:"replace_all,omitempty"`
|
||||
}
|
||||
|
||||
type EditPermissionsParams struct {
|
||||
@@ -58,31 +59,33 @@ To make a file edit, provide the following:
|
||||
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
|
||||
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
|
||||
3. new_string: The edited text to replace the old_string
|
||||
4. replace_all: Replace all occurrences of old_string (default false)
|
||||
|
||||
Special cases:
|
||||
- To create a new file: provide file_path and new_string, leave old_string empty
|
||||
- To delete content: provide file_path and old_string, leave new_string empty
|
||||
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file.
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file by default. Set replace_all to true to replace all occurrences.
|
||||
|
||||
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
|
||||
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
1. UNIQUENESS: When replace_all is false (default), the old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
- Include AT LEAST 3-5 lines of context BEFORE the change point
|
||||
- Include AT LEAST 3-5 lines of context AFTER the change point
|
||||
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
|
||||
|
||||
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Make separate calls to this tool for each instance
|
||||
2. SINGLE INSTANCE: When replace_all is false, this tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Set replace_all to true to replace all occurrences at once
|
||||
- Or make separate calls to this tool for each instance
|
||||
- Each call must uniquely identify its specific instance using extensive context
|
||||
|
||||
3. VERIFICATION: Before using this tool:
|
||||
- Check how many instances of the target text exist in the file
|
||||
- If multiple instances exist, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance
|
||||
- If multiple instances exist and replace_all is false, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance or use replace_all
|
||||
|
||||
WARNING: If you do not follow these requirements:
|
||||
- The tool will fail if old_string matches multiple locations
|
||||
- The tool will fail if old_string matches multiple locations and replace_all is false
|
||||
- The tool will fail if old_string doesn't match exactly (including whitespace)
|
||||
- You may change the wrong instance if you don't include enough context
|
||||
|
||||
@@ -129,6 +132,10 @@ func (e *editTool) Info() ToolInfo {
|
||||
"type": "string",
|
||||
"description": "The text to replace it with",
|
||||
},
|
||||
"replace_all": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences of old_string (default false)",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "old_string", "new_string"},
|
||||
}
|
||||
@@ -152,20 +159,20 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
var err error
|
||||
|
||||
if params.OldString == "" {
|
||||
response, err = e.createNewFile(ctx, params.FilePath, params.NewString)
|
||||
response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
}
|
||||
|
||||
if params.NewString == "" {
|
||||
response, err = e.deleteContent(ctx, params.FilePath, params.OldString)
|
||||
response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString)
|
||||
response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
@@ -182,7 +189,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (ToolResponse, error) {
|
||||
func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) {
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
if fileInfo.IsDir() {
|
||||
@@ -217,6 +224,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: EditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Create file %s", filePath),
|
||||
@@ -264,7 +272,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
|
||||
), nil
|
||||
}
|
||||
|
||||
func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (ToolResponse, error) {
|
||||
func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
@@ -297,17 +305,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
|
||||
oldContent := string(content)
|
||||
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
var newContent string
|
||||
var deletionCount int
|
||||
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
|
||||
}
|
||||
if replaceAll {
|
||||
newContent = strings.ReplaceAll(oldContent, oldString, "")
|
||||
deletionCount = strings.Count(oldContent, oldString)
|
||||
if deletionCount == 0 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
} else {
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
|
||||
newContent := oldContent[:index] + oldContent[index+len(oldString):]
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
|
||||
}
|
||||
|
||||
newContent = oldContent[:index] + oldContent[index+len(oldString):]
|
||||
deletionCount = 1
|
||||
}
|
||||
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
|
||||
@@ -330,6 +350,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: EditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Delete content from file %s", filePath),
|
||||
@@ -385,7 +406,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
|
||||
), nil
|
||||
}
|
||||
|
||||
func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (ToolResponse, error) {
|
||||
func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
@@ -418,17 +439,29 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
|
||||
oldContent := string(content)
|
||||
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
var newContent string
|
||||
var replacementCount int
|
||||
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
|
||||
}
|
||||
if replaceAll {
|
||||
newContent = strings.ReplaceAll(oldContent, oldString, newString)
|
||||
replacementCount = strings.Count(oldContent, oldString)
|
||||
if replacementCount == 0 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
} else {
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
|
||||
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
|
||||
}
|
||||
|
||||
newContent = oldContent[:index] + newString + oldContent[index+len(oldString):]
|
||||
replacementCount = 1
|
||||
}
|
||||
|
||||
if oldContent == newContent {
|
||||
return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
|
||||
@@ -452,6 +485,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: EditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Replace content in file %s", filePath),
|
||||
|
||||
@@ -136,6 +136,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: t.workingDir,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: FetchToolName,
|
||||
Action: "fetch",
|
||||
Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
|
||||
|
||||
467
internal/llm/tools/multiedit.go
Normal file
467
internal/llm/tools/multiedit.go
Normal file
@@ -0,0 +1,467 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/diff"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/lsp"
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
)
|
||||
|
||||
type MultiEditOperation struct {
|
||||
OldString string `json:"old_string"`
|
||||
NewString string `json:"new_string"`
|
||||
ReplaceAll bool `json:"replace_all,omitempty"`
|
||||
}
|
||||
|
||||
type MultiEditParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Edits []MultiEditOperation `json:"edits"`
|
||||
}
|
||||
|
||||
type MultiEditPermissionsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
OldContent string `json:"old_content,omitempty"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
}
|
||||
|
||||
type MultiEditResponseMetadata struct {
|
||||
Additions int `json:"additions"`
|
||||
Removals int `json:"removals"`
|
||||
OldContent string `json:"old_content,omitempty"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
EditsApplied int `json:"edits_applied"`
|
||||
}
|
||||
|
||||
type multiEditTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
permissions permission.Service
|
||||
files history.Service
|
||||
workingDir string
|
||||
}
|
||||
|
||||
const (
|
||||
MultiEditToolName = "multiedit"
|
||||
multiEditDescription = `This is a tool for making multiple edits to a single file in one operation. It is built on top of the Edit tool and allows you to perform multiple find-and-replace operations efficiently. Prefer this tool over the Edit tool when you need to make multiple edits to the same file.
|
||||
|
||||
Before using this tool:
|
||||
|
||||
1. Use the Read tool to understand the file's contents and context
|
||||
|
||||
2. Verify the directory path is correct
|
||||
|
||||
To make multiple file edits, provide the following:
|
||||
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
|
||||
2. edits: An array of edit operations to perform, where each edit contains:
|
||||
- old_string: The text to replace (must match the file contents exactly, including all whitespace and indentation)
|
||||
- new_string: The edited text to replace the old_string
|
||||
- replace_all: Replace all occurrences of old_string. This parameter is optional and defaults to false.
|
||||
|
||||
IMPORTANT:
|
||||
- All edits are applied in sequence, in the order they are provided
|
||||
- Each edit operates on the result of the previous edit
|
||||
- All edits must be valid for the operation to succeed - if any edit fails, none will be applied
|
||||
- This tool is ideal when you need to make several changes to different parts of the same file
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
1. All edits follow the same requirements as the single Edit tool
|
||||
2. The edits are atomic - either all succeed or none are applied
|
||||
3. Plan your edits carefully to avoid conflicts between sequential operations
|
||||
|
||||
WARNING:
|
||||
- The tool will fail if edits.old_string doesn't match the file contents exactly (including whitespace)
|
||||
- The tool will fail if edits.old_string and edits.new_string are the same
|
||||
- Since edits are applied in sequence, ensure that earlier edits don't affect the text that later edits are trying to find
|
||||
|
||||
When making edits:
|
||||
- Ensure all edits result in idiomatic, correct code
|
||||
- Do not leave the code in a broken state
|
||||
- Always use absolute file paths (starting with /)
|
||||
- Only use emojis if the user explicitly requests it. Avoid adding emojis to files unless asked.
|
||||
- Use replace_all for replacing and renaming strings across the file. This parameter is useful if you want to rename a variable for instance.
|
||||
|
||||
If you want to create a new file, use:
|
||||
- A new file path, including dir name if needed
|
||||
- First edit: empty old_string and the new file's contents as new_string
|
||||
- Subsequent edits: normal edit operations on the created content`
|
||||
)
|
||||
|
||||
func NewMultiEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool {
|
||||
return &multiEditTool{
|
||||
lspClients: lspClients,
|
||||
permissions: permissions,
|
||||
files: files,
|
||||
workingDir: workingDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiEditTool) Name() string {
|
||||
return MultiEditToolName
|
||||
}
|
||||
|
||||
func (m *multiEditTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: MultiEditToolName,
|
||||
Description: multiEditDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to modify",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace",
|
||||
},
|
||||
"new_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace it with",
|
||||
},
|
||||
"replace_all": map[string]any{
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Replace all occurrences of old_string (default false).",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_string", "new_string"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
"minItems": 1,
|
||||
"description": "Array of edit operations to perform sequentially on the file",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "edits"},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params MultiEditParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return NewTextErrorResponse("file_path is required"), nil
|
||||
}
|
||||
|
||||
if len(params.Edits) == 0 {
|
||||
return NewTextErrorResponse("at least one edit operation is required"), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
params.FilePath = filepath.Join(m.workingDir, params.FilePath)
|
||||
}
|
||||
|
||||
// Validate all edits before applying any
|
||||
if err := m.validateEdits(params.Edits); err != nil {
|
||||
return NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
var response ToolResponse
|
||||
var err error
|
||||
|
||||
// Handle file creation case (first edit has empty old_string)
|
||||
if len(params.Edits) > 0 && params.Edits[0].OldString == "" {
|
||||
response, err = m.processMultiEditWithCreation(ctx, params, call)
|
||||
} else {
|
||||
response, err = m.processMultiEditExistingFile(ctx, params, call)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
if response.IsError {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Wait for LSP diagnostics and add them to the response
|
||||
waitForLspDiagnostics(ctx, params.FilePath, m.lspClients)
|
||||
text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
|
||||
text += getDiagnostics(params.FilePath, m.lspClients)
|
||||
response.Content = text
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error {
|
||||
for i, edit := range edits {
|
||||
if edit.OldString == edit.NewString {
|
||||
return fmt.Errorf("edit %d: old_string and new_string are identical", i+1)
|
||||
}
|
||||
// Only the first edit can have empty old_string (for file creation)
|
||||
if i > 0 && edit.OldString == "" {
|
||||
return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
|
||||
// First edit creates the file
|
||||
firstEdit := params.Edits[0]
|
||||
if firstEdit.OldString != "" {
|
||||
return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
if _, err := os.Stat(params.FilePath); err == nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
dir := filepath.Dir(params.FilePath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
|
||||
}
|
||||
|
||||
// Start with the content from the first edit
|
||||
currentContent := firstEdit.NewString
|
||||
|
||||
// Apply remaining edits to the content
|
||||
for i := 1; i < len(params.Edits); i++ {
|
||||
edit := params.Edits[i]
|
||||
newContent, err := m.applyEditToContent(currentContent, edit)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
|
||||
}
|
||||
currentContent = newContent
|
||||
}
|
||||
|
||||
// Get session and message IDs
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
|
||||
}
|
||||
|
||||
// Check permissions
|
||||
_, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
|
||||
rootDir := m.workingDir
|
||||
permissionPath := filepath.Dir(params.FilePath)
|
||||
if strings.HasPrefix(params.FilePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
}
|
||||
|
||||
p := m.permissions.Request(permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: MultiEditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)),
|
||||
Params: MultiEditPermissionsParams{
|
||||
FilePath: params.FilePath,
|
||||
OldContent: "",
|
||||
NewContent: currentContent,
|
||||
},
|
||||
})
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
// Write the file
|
||||
err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Update file history
|
||||
_, err = m.files.Create(ctx, sessionID, params.FilePath, "")
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
}
|
||||
|
||||
_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(params.FilePath)
|
||||
recordFileRead(params.FilePath)
|
||||
|
||||
return WithResponseMetadata(
|
||||
NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)),
|
||||
MultiEditResponseMetadata{
|
||||
OldContent: "",
|
||||
NewContent: currentContent,
|
||||
Additions: additions,
|
||||
Removals: removals,
|
||||
EditsApplied: len(params.Edits),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
|
||||
// Validate file exists and is readable
|
||||
fileInfo, err := os.Stat(params.FilePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil
|
||||
}
|
||||
return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
|
||||
}
|
||||
|
||||
// Check if file was read before editing
|
||||
if getLastReadTime(params.FilePath).IsZero() {
|
||||
return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
|
||||
}
|
||||
|
||||
// Check if file was modified since last read
|
||||
modTime := fileInfo.ModTime()
|
||||
lastRead := getLastReadTime(params.FilePath)
|
||||
if modTime.After(lastRead) {
|
||||
return NewTextErrorResponse(
|
||||
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
|
||||
)), nil
|
||||
}
|
||||
|
||||
// Read current file content
|
||||
content, err := os.ReadFile(params.FilePath)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
oldContent := string(content)
|
||||
currentContent := oldContent
|
||||
|
||||
// Apply all edits sequentially
|
||||
for i, edit := range params.Edits {
|
||||
newContent, err := m.applyEditToContent(currentContent, edit)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
|
||||
}
|
||||
currentContent = newContent
|
||||
}
|
||||
|
||||
// Check if content actually changed
|
||||
if oldContent == currentContent {
|
||||
return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil
|
||||
}
|
||||
|
||||
// Get session and message IDs
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file")
|
||||
}
|
||||
|
||||
// Generate diff and check permissions
|
||||
_, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
|
||||
rootDir := m.workingDir
|
||||
permissionPath := filepath.Dir(params.FilePath)
|
||||
if strings.HasPrefix(params.FilePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
}
|
||||
|
||||
p := m.permissions.Request(permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: MultiEditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath),
|
||||
Params: MultiEditPermissionsParams{
|
||||
FilePath: params.FilePath,
|
||||
OldContent: oldContent,
|
||||
NewContent: currentContent,
|
||||
},
|
||||
})
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
// Write the updated content
|
||||
err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Update file history
|
||||
file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
}
|
||||
}
|
||||
if file.Content != oldContent {
|
||||
// User manually changed the content, store an intermediate version
|
||||
_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new version
|
||||
_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(params.FilePath)
|
||||
recordFileRead(params.FilePath)
|
||||
|
||||
return WithResponseMetadata(
|
||||
NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)),
|
||||
MultiEditResponseMetadata{
|
||||
OldContent: oldContent,
|
||||
NewContent: currentContent,
|
||||
Additions: additions,
|
||||
Removals: removals,
|
||||
EditsApplied: len(params.Edits),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (m *multiEditTool) applyEditToContent(content string, edit MultiEditOperation) (string, error) {
|
||||
if edit.OldString == "" && edit.NewString == "" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
if edit.OldString == "" {
|
||||
return "", fmt.Errorf("old_string cannot be empty for content replacement")
|
||||
}
|
||||
|
||||
var newContent string
|
||||
var replacementCount int
|
||||
|
||||
if edit.ReplaceAll {
|
||||
newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString)
|
||||
replacementCount = strings.Count(content, edit.OldString)
|
||||
if replacementCount == 0 {
|
||||
return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
|
||||
}
|
||||
} else {
|
||||
index := strings.Index(content, edit.OldString)
|
||||
if index == -1 {
|
||||
return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
|
||||
}
|
||||
|
||||
lastIndex := strings.LastIndex(content, edit.OldString)
|
||||
if index != lastIndex {
|
||||
return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true")
|
||||
}
|
||||
|
||||
newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):]
|
||||
replacementCount = 1
|
||||
}
|
||||
|
||||
return newContent, nil
|
||||
}
|
||||
@@ -181,6 +181,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolCallID: call.ID,
|
||||
ToolName: WriteToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Create file %s", filePath),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
@@ -14,6 +16,7 @@ var ErrorPermissionDenied = errors.New("permission denied")
|
||||
|
||||
type CreatePermissionRequest struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Description string `json:"description"`
|
||||
Action string `json:"action"`
|
||||
@@ -21,9 +24,16 @@ type CreatePermissionRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type PermissionNotification struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Granted bool `json:"granted"`
|
||||
Denied bool `json:"denied"`
|
||||
}
|
||||
|
||||
type PermissionRequest struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Description string `json:"description"`
|
||||
Action string `json:"action"`
|
||||
@@ -38,22 +48,32 @@ type Service interface {
|
||||
Deny(permission PermissionRequest)
|
||||
Request(opts CreatePermissionRequest) bool
|
||||
AutoApproveSession(sessionID string)
|
||||
SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification]
|
||||
}
|
||||
|
||||
type permissionService struct {
|
||||
*pubsub.Broker[PermissionRequest]
|
||||
|
||||
notificationBroker *pubsub.Broker[PermissionNotification]
|
||||
workingDir string
|
||||
sessionPermissions []PermissionRequest
|
||||
sessionPermissionsMu sync.RWMutex
|
||||
pendingRequests sync.Map
|
||||
autoApproveSessions []string
|
||||
autoApproveSessions map[string]bool
|
||||
autoApproveSessionsMu sync.RWMutex
|
||||
skip bool
|
||||
allowedTools []string
|
||||
|
||||
// used to make sure we only process one request at a time
|
||||
requestMu sync.Mutex
|
||||
activeRequest *PermissionRequest
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
|
||||
s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
|
||||
ToolCallID: permission.ToolCallID,
|
||||
Granted: true,
|
||||
})
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
@@ -62,20 +82,41 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
|
||||
s.sessionPermissionsMu.Lock()
|
||||
s.sessionPermissions = append(s.sessionPermissions, permission)
|
||||
s.sessionPermissionsMu.Unlock()
|
||||
|
||||
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
|
||||
s.activeRequest = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) Grant(permission PermissionRequest) {
|
||||
s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
|
||||
ToolCallID: permission.ToolCallID,
|
||||
Granted: true,
|
||||
})
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- true
|
||||
}
|
||||
|
||||
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
|
||||
s.activeRequest = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) Deny(permission PermissionRequest) {
|
||||
s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
|
||||
ToolCallID: permission.ToolCallID,
|
||||
Granted: false,
|
||||
Denied: true,
|
||||
})
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
respCh.(chan bool) <- false
|
||||
}
|
||||
|
||||
if s.activeRequest != nil && s.activeRequest.ID == permission.ID {
|
||||
s.activeRequest = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
@@ -83,6 +124,13 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// tell the UI that a permission was requested
|
||||
s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
|
||||
ToolCallID: opts.ToolCallID,
|
||||
})
|
||||
s.requestMu.Lock()
|
||||
defer s.requestMu.Unlock()
|
||||
|
||||
// Check if the tool/action combination is in the allowlist
|
||||
commandKey := opts.ToolName + ":" + opts.Action
|
||||
if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) {
|
||||
@@ -90,7 +138,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
}
|
||||
|
||||
s.autoApproveSessionsMu.RLock()
|
||||
autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
|
||||
autoApprove := s.autoApproveSessions[opts.SessionID]
|
||||
s.autoApproveSessionsMu.RUnlock()
|
||||
|
||||
if autoApprove {
|
||||
@@ -101,10 +149,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
if dir == "." {
|
||||
dir = s.workingDir
|
||||
}
|
||||
slog.Info("Requesting permission", "session_id", opts.SessionID, "tool_name", opts.ToolName, "action", opts.Action, "path", dir)
|
||||
permission := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
Path: dir,
|
||||
SessionID: opts.SessionID,
|
||||
ToolCallID: opts.ToolCallID,
|
||||
ToolName: opts.ToolName,
|
||||
Description: opts.Description,
|
||||
Action: opts.Action,
|
||||
@@ -120,29 +170,45 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
}
|
||||
s.sessionPermissionsMu.RUnlock()
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
s.sessionPermissionsMu.RLock()
|
||||
for _, p := range s.sessionPermissions {
|
||||
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
|
||||
s.sessionPermissionsMu.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
s.sessionPermissionsMu.RUnlock()
|
||||
|
||||
s.activeRequest = &permission
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
s.pendingRequests.Store(permission.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permission.ID)
|
||||
|
||||
// Publish the request
|
||||
s.Publish(pubsub.CreatedEvent, permission)
|
||||
|
||||
// Wait for the response indefinitely
|
||||
return <-respCh
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(sessionID string) {
|
||||
s.autoApproveSessionsMu.Lock()
|
||||
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
|
||||
s.autoApproveSessions[sessionID] = true
|
||||
s.autoApproveSessionsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *permissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] {
|
||||
return s.notificationBroker.Subscribe(ctx)
|
||||
}
|
||||
|
||||
func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service {
|
||||
return &permissionService{
|
||||
Broker: pubsub.NewBroker[PermissionRequest](),
|
||||
workingDir: workingDir,
|
||||
sessionPermissions: make([]PermissionRequest, 0),
|
||||
skip: skip,
|
||||
allowedTools: allowedTools,
|
||||
Broker: pubsub.NewBroker[PermissionRequest](),
|
||||
notificationBroker: pubsub.NewBroker[PermissionNotification](),
|
||||
workingDir: workingDir,
|
||||
sessionPermissions: make([]PermissionRequest, 0),
|
||||
autoApproveSessions: make(map[string]bool),
|
||||
skip: skip,
|
||||
allowedTools: allowedTools,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPermissionService_AllowedCommands(t *testing.T) {
|
||||
@@ -90,3 +93,159 @@ func TestPermissionService_SkipMode(t *testing.T) {
|
||||
t.Error("expected permission to be granted in skip mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionService_SequentialProperties(t *testing.T) {
|
||||
t.Run("Sequential permission requests with persistent grants", func(t *testing.T) {
|
||||
service := NewPermissionService("/tmp", false, []string{})
|
||||
|
||||
req1 := CreatePermissionRequest{
|
||||
SessionID: "session1",
|
||||
ToolName: "file_tool",
|
||||
Description: "Read file",
|
||||
Action: "read",
|
||||
Params: map[string]string{"file": "test.txt"},
|
||||
Path: "/tmp/test.txt",
|
||||
}
|
||||
|
||||
var result1 bool
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
events := service.Subscribe(t.Context())
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result1 = service.Request(req1)
|
||||
}()
|
||||
|
||||
var permissionReq PermissionRequest
|
||||
event := <-events
|
||||
|
||||
permissionReq = event.Payload
|
||||
service.GrantPersistent(permissionReq)
|
||||
|
||||
wg.Wait()
|
||||
assert.True(t, result1, "First request should be granted")
|
||||
|
||||
// Second identical request should be automatically approved due to persistent permission
|
||||
req2 := CreatePermissionRequest{
|
||||
SessionID: "session1",
|
||||
ToolName: "file_tool",
|
||||
Description: "Read file again",
|
||||
Action: "read",
|
||||
Params: map[string]string{"file": "test.txt"},
|
||||
Path: "/tmp/test.txt",
|
||||
}
|
||||
result2 := service.Request(req2)
|
||||
assert.True(t, result2, "Second request should be auto-approved")
|
||||
})
|
||||
t.Run("Sequential requests with temporary grants", func(t *testing.T) {
|
||||
service := NewPermissionService("/tmp", false, []string{})
|
||||
|
||||
req := CreatePermissionRequest{
|
||||
SessionID: "session2",
|
||||
ToolName: "file_tool",
|
||||
Description: "Write file",
|
||||
Action: "write",
|
||||
Params: map[string]string{"file": "test.txt"},
|
||||
Path: "/tmp/test.txt",
|
||||
}
|
||||
|
||||
events := service.Subscribe(t.Context())
|
||||
var result1 bool
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result1 = service.Request(req)
|
||||
}()
|
||||
|
||||
var permissionReq PermissionRequest
|
||||
event := <-events
|
||||
permissionReq = event.Payload
|
||||
|
||||
service.Grant(permissionReq)
|
||||
wg.Wait()
|
||||
assert.True(t, result1, "First request should be granted")
|
||||
|
||||
var result2 bool
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result2 = service.Request(req)
|
||||
}()
|
||||
|
||||
event = <-events
|
||||
permissionReq = event.Payload
|
||||
service.Deny(permissionReq)
|
||||
wg.Wait()
|
||||
assert.False(t, result2, "Second request should be denied")
|
||||
})
|
||||
t.Run("Concurrent requests with different outcomes", func(t *testing.T) {
|
||||
service := NewPermissionService("/tmp", false, []string{})
|
||||
|
||||
events := service.Subscribe(t.Context())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, 0)
|
||||
|
||||
requests := []CreatePermissionRequest{
|
||||
{
|
||||
SessionID: "concurrent1",
|
||||
ToolName: "tool1",
|
||||
Action: "action1",
|
||||
Path: "/tmp/file1.txt",
|
||||
Description: "First concurrent request",
|
||||
},
|
||||
{
|
||||
SessionID: "concurrent2",
|
||||
ToolName: "tool2",
|
||||
Action: "action2",
|
||||
Path: "/tmp/file2.txt",
|
||||
Description: "Second concurrent request",
|
||||
},
|
||||
{
|
||||
SessionID: "concurrent3",
|
||||
ToolName: "tool3",
|
||||
Action: "action3",
|
||||
Path: "/tmp/file3.txt",
|
||||
Description: "Third concurrent request",
|
||||
},
|
||||
}
|
||||
|
||||
for i, req := range requests {
|
||||
wg.Add(1)
|
||||
go func(index int, request CreatePermissionRequest) {
|
||||
defer wg.Done()
|
||||
results = append(results, service.Request(request))
|
||||
}(i, req)
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
event := <-events
|
||||
switch event.Payload.ToolName {
|
||||
case "tool1":
|
||||
service.Grant(event.Payload)
|
||||
case "tool2":
|
||||
service.GrantPersistent(event.Payload)
|
||||
case "tool3":
|
||||
service.Deny(event.Payload)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
grantedCount := 0
|
||||
for _, result := range results {
|
||||
if result {
|
||||
grantedCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied")
|
||||
secondReq := requests[1]
|
||||
secondReq.Description = "Repeat of second request"
|
||||
result := service.Request(secondReq)
|
||||
assert.True(t, result, "Repeated request should be auto-approved due to persistent permission")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/charmbracelet/crush/internal/app"
|
||||
"github.com/charmbracelet/crush/internal/llm/agent"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
"github.com/charmbracelet/crush/internal/pubsub"
|
||||
"github.com/charmbracelet/crush/internal/session"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/chat/messages"
|
||||
@@ -85,6 +87,8 @@ func (m *messageListCmp) Init() tea.Cmd {
|
||||
// Update handles incoming messages and updates the component state.
|
||||
func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case pubsub.Event[permission.PermissionNotification]:
|
||||
return m, m.handlePermissionRequest(msg.Payload)
|
||||
case SessionSelectedMsg:
|
||||
if msg.ID != m.session.ID {
|
||||
cmd := m.SetSession(msg)
|
||||
@@ -124,6 +128,20 @@ func (m *messageListCmp) View() string {
|
||||
)
|
||||
}
|
||||
|
||||
func (m *messageListCmp) handlePermissionRequest(permission permission.PermissionNotification) tea.Cmd {
|
||||
items := m.listCmp.Items()
|
||||
slog.Info("Handling permission request", "tool_call_id", permission.ToolCallID, "granted", permission.Granted)
|
||||
if toolCallIndex := m.findToolCallByID(items, permission.ToolCallID); toolCallIndex != NotFound {
|
||||
toolCall := items[toolCallIndex].(messages.ToolCallCmp)
|
||||
toolCall.SetPermissionRequested()
|
||||
if permission.Granted {
|
||||
toolCall.SetPermissionGranted()
|
||||
}
|
||||
m.listCmp.UpdateItem(toolCall.ID(), toolCall)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleChildSession handles messages from child sessions (agent tools).
|
||||
func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message]) tea.Cmd {
|
||||
var cmds []tea.Cmd
|
||||
@@ -158,6 +176,7 @@ func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message])
|
||||
nestedCall := messages.NewToolCallCmp(
|
||||
event.Payload.ID,
|
||||
tc,
|
||||
m.app.Permissions,
|
||||
messages.WithToolCallNested(true),
|
||||
)
|
||||
cmds = append(cmds, nestedCall.Init())
|
||||
@@ -199,7 +218,12 @@ func (m *messageListCmp) handleMessageEvent(event pubsub.Event[message.Message])
|
||||
if event.Payload.SessionID != m.session.ID {
|
||||
return m.handleChildSession(event)
|
||||
}
|
||||
return m.handleUpdateAssistantMessage(event.Payload)
|
||||
switch event.Payload.Role {
|
||||
case message.Assistant:
|
||||
return m.handleUpdateAssistantMessage(event.Payload)
|
||||
case message.Tool:
|
||||
return m.handleToolMessage(event.Payload)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -371,7 +395,7 @@ func (m *messageListCmp) updateOrAddToolCall(msg message.Message, tc message.Too
|
||||
}
|
||||
|
||||
// Add new tool call if not found
|
||||
return m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc))
|
||||
return m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions))
|
||||
}
|
||||
|
||||
// handleNewAssistantMessage processes new assistant messages and their tool calls.
|
||||
@@ -390,7 +414,7 @@ func (m *messageListCmp) handleNewAssistantMessage(msg message.Message) tea.Cmd
|
||||
|
||||
// Add tool calls
|
||||
for _, tc := range msg.ToolCalls() {
|
||||
cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc))
|
||||
cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions))
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
@@ -473,7 +497,7 @@ func (m *messageListCmp) convertAssistantMessage(msg message.Message, toolResult
|
||||
// Add tool calls with their results and status
|
||||
for _, tc := range msg.ToolCalls() {
|
||||
options := m.buildToolCallOptions(tc, msg, toolResultMap)
|
||||
uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, options...))
|
||||
uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions, options...))
|
||||
// If this tool call is the agent tool, fetch nested tool calls
|
||||
if tc.Name == agent.AgentToolName {
|
||||
nestedMessages, _ := m.app.Messages.List(context.Background(), tc.ID)
|
||||
|
||||
@@ -166,6 +166,7 @@ func init() {
|
||||
registry.register(tools.DownloadToolName, func() renderer { return downloadRenderer{} })
|
||||
registry.register(tools.ViewToolName, func() renderer { return viewRenderer{} })
|
||||
registry.register(tools.EditToolName, func() renderer { return editRenderer{} })
|
||||
registry.register(tools.MultiEditToolName, func() renderer { return multiEditRenderer{} })
|
||||
registry.register(tools.WriteToolName, func() renderer { return writeRenderer{} })
|
||||
registry.register(tools.FetchToolName, func() renderer { return fetchRenderer{} })
|
||||
registry.register(tools.GlobToolName, func() renderer { return globRenderer{} })
|
||||
@@ -316,6 +317,57 @@ func (er editRenderer) Render(v *toolCallCmp) string {
|
||||
})
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Multi-Edit renderer
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// multiEditRenderer handles multiple file edits with diff visualization
|
||||
type multiEditRenderer struct {
|
||||
baseRenderer
|
||||
}
|
||||
|
||||
// Render displays the multi-edited file with a formatted diff of changes
|
||||
func (mer multiEditRenderer) Render(v *toolCallCmp) string {
|
||||
t := styles.CurrentTheme()
|
||||
var params tools.MultiEditParams
|
||||
var args []string
|
||||
if err := mer.unmarshalParams(v.call.Input, ¶ms); err == nil {
|
||||
file := fsext.PrettyPath(params.FilePath)
|
||||
editsCount := len(params.Edits)
|
||||
args = newParamBuilder().
|
||||
addMain(file).
|
||||
addKeyValue("edits", fmt.Sprintf("%d", editsCount)).
|
||||
build()
|
||||
}
|
||||
|
||||
return mer.renderWithParams(v, "Multi-Edit", args, func() string {
|
||||
var meta tools.MultiEditResponseMetadata
|
||||
if err := mer.unmarshalParams(v.result.Metadata, &meta); err != nil {
|
||||
return renderPlainContent(v, v.result.Content)
|
||||
}
|
||||
|
||||
formatter := core.DiffFormatter().
|
||||
Before(fsext.PrettyPath(params.FilePath), meta.OldContent).
|
||||
After(fsext.PrettyPath(params.FilePath), meta.NewContent).
|
||||
Width(v.textWidth() - 2) // -2 for padding
|
||||
if v.textWidth() > 120 {
|
||||
formatter = formatter.Split()
|
||||
}
|
||||
// add a message to the bottom if the content was truncated
|
||||
formatted := formatter.String()
|
||||
if lipgloss.Height(formatted) > responseContextHeight {
|
||||
contentLines := strings.Split(formatted, "\n")
|
||||
truncateMessage := t.S().Muted.
|
||||
Background(t.BgBaseLighter).
|
||||
PaddingLeft(2).
|
||||
Width(v.textWidth() - 4).
|
||||
Render(fmt.Sprintf("… (%d lines)", len(contentLines)-responseContextHeight))
|
||||
formatted = strings.Join(contentLines[:responseContextHeight], "\n") + "\n" + truncateMessage
|
||||
}
|
||||
return formatted
|
||||
})
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Write renderer
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -672,7 +724,11 @@ func earlyState(header string, v *toolCallCmp) (string, bool) {
|
||||
case v.cancelled:
|
||||
message = t.S().Base.Foreground(t.FgSubtle).Render("Canceled.")
|
||||
case v.result.ToolCallID == "":
|
||||
message = t.S().Base.Foreground(t.FgSubtle).Render("Waiting for tool to start...")
|
||||
if v.permissionRequested && !v.permissionGranted {
|
||||
message = t.S().Base.Foreground(t.FgSubtle).Render("Requesting for permission...")
|
||||
} else {
|
||||
message = t.S().Base.Foreground(t.FgSubtle).Render("Waiting for tool response...")
|
||||
}
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
@@ -799,6 +855,8 @@ func prettifyToolName(name string) string {
|
||||
return "Download"
|
||||
case tools.EditToolName:
|
||||
return "Edit"
|
||||
case tools.MultiEditToolName:
|
||||
return "Multi-Edit"
|
||||
case tools.FetchToolName:
|
||||
return "Fetch"
|
||||
case tools.GlobToolName:
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/anim"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
|
||||
"github.com/charmbracelet/crush/internal/tui/styles"
|
||||
@@ -30,6 +31,8 @@ type ToolCallCmp interface {
|
||||
SetNestedToolCalls([]ToolCallCmp) // Set nested tool calls
|
||||
SetIsNested(bool) // Set whether this tool call is nested
|
||||
ID() string
|
||||
SetPermissionRequested() // Mark permission request
|
||||
SetPermissionGranted() // Mark permission granted
|
||||
}
|
||||
|
||||
// toolCallCmp implements the ToolCallCmp interface for displaying tool calls.
|
||||
@@ -40,10 +43,12 @@ type toolCallCmp struct {
|
||||
isNested bool // Whether this tool call is nested within another
|
||||
|
||||
// Tool call data and state
|
||||
parentMessageID string // ID of the message that initiated this tool call
|
||||
call message.ToolCall // The tool call being executed
|
||||
result message.ToolResult // The result of the tool execution
|
||||
cancelled bool // Whether the tool call was cancelled
|
||||
parentMessageID string // ID of the message that initiated this tool call
|
||||
call message.ToolCall // The tool call being executed
|
||||
result message.ToolResult // The result of the tool execution
|
||||
cancelled bool // Whether the tool call was cancelled
|
||||
permissionRequested bool
|
||||
permissionGranted bool
|
||||
|
||||
// Animation state for pending tool calls
|
||||
spinning bool // Whether to show loading animation
|
||||
@@ -81,9 +86,21 @@ func WithToolCallNestedCalls(calls []ToolCallCmp) ToolCallOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolPermissionRequested() ToolCallOption {
|
||||
return func(m *toolCallCmp) {
|
||||
m.permissionRequested = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolPermissionGranted() ToolCallOption {
|
||||
return func(m *toolCallCmp) {
|
||||
m.permissionGranted = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewToolCallCmp creates a new tool call component with the given parent message ID,
|
||||
// tool call, and optional configuration
|
||||
func NewToolCallCmp(parentMessageID string, tc message.ToolCall, opts ...ToolCallOption) ToolCallCmp {
|
||||
func NewToolCallCmp(parentMessageID string, tc message.ToolCall, permissions permission.Service, opts ...ToolCallOption) ToolCallCmp {
|
||||
m := &toolCallCmp{
|
||||
call: tc,
|
||||
parentMessageID: parentMessageID,
|
||||
@@ -316,3 +333,13 @@ func (m *toolCallCmp) Spinning() bool {
|
||||
func (m *toolCallCmp) ID() string {
|
||||
return m.call.ID
|
||||
}
|
||||
|
||||
// SetPermissionRequested marks that a permission request was made for this tool call
|
||||
func (m *toolCallCmp) SetPermissionRequested() {
|
||||
m.permissionRequested = true
|
||||
}
|
||||
|
||||
// SetPermissionGranted marks that permission was granted for this tool call
|
||||
func (m *toolCallCmp) SetPermissionGranted() {
|
||||
m.permissionGranted = true
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (p *permissionDialogCmp) Init() tea.Cmd {
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) supportsDiffView() bool {
|
||||
return p.permission.ToolName == tools.EditToolName || p.permission.ToolName == tools.WriteToolName
|
||||
return p.permission.ToolName == tools.EditToolName || p.permission.ToolName == tools.WriteToolName || p.permission.ToolName == tools.MultiEditToolName
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
@@ -305,6 +305,20 @@ func (p *permissionDialogCmp) renderHeader() string {
|
||||
),
|
||||
baseStyle.Render(strings.Repeat(" ", p.width)),
|
||||
)
|
||||
case tools.MultiEditToolName:
|
||||
params := p.permission.Params.(tools.MultiEditPermissionsParams)
|
||||
fileKey := t.S().Muted.Render("File")
|
||||
filePath := t.S().Text.
|
||||
Width(p.width - lipgloss.Width(fileKey)).
|
||||
Render(fmt.Sprintf(" %s", fsext.PrettyPath(params.FilePath)))
|
||||
headerParts = append(headerParts,
|
||||
lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
fileKey,
|
||||
filePath,
|
||||
),
|
||||
baseStyle.Render(strings.Repeat(" ", p.width)),
|
||||
)
|
||||
case tools.FetchToolName:
|
||||
headerParts = append(headerParts, t.S().Muted.Width(p.width).Bold(true).Render("URL"))
|
||||
}
|
||||
@@ -329,6 +343,8 @@ func (p *permissionDialogCmp) getOrGenerateContent() string {
|
||||
content = p.generateEditContent()
|
||||
case tools.WriteToolName:
|
||||
content = p.generateWriteContent()
|
||||
case tools.MultiEditToolName:
|
||||
content = p.generateMultiEditContent()
|
||||
case tools.FetchToolName:
|
||||
content = p.generateFetchContent()
|
||||
default:
|
||||
@@ -435,6 +451,28 @@ func (p *permissionDialogCmp) generateDownloadContent() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) generateMultiEditContent() string {
|
||||
if pr, ok := p.permission.Params.(tools.MultiEditPermissionsParams); ok {
|
||||
// Use the cache for diff rendering
|
||||
formatter := core.DiffFormatter().
|
||||
Before(fsext.PrettyPath(pr.FilePath), pr.OldContent).
|
||||
After(fsext.PrettyPath(pr.FilePath), pr.NewContent).
|
||||
Height(p.contentViewPort.Height()).
|
||||
Width(p.contentViewPort.Width()).
|
||||
XOffset(p.diffXOffset).
|
||||
YOffset(p.diffYOffset)
|
||||
if p.useDiffSplitMode() {
|
||||
formatter = formatter.Split()
|
||||
} else {
|
||||
formatter = formatter.Unified()
|
||||
}
|
||||
|
||||
diff := formatter.String()
|
||||
return diff
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) generateFetchContent() string {
|
||||
t := styles.CurrentTheme()
|
||||
baseStyle := t.S().Base.Background(t.BgSubtle)
|
||||
@@ -579,6 +617,9 @@ func (p *permissionDialogCmp) SetSize() tea.Cmd {
|
||||
case tools.WriteToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.8)
|
||||
case tools.MultiEditToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.8)
|
||||
case tools.FetchToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.3)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
"github.com/charmbracelet/crush/internal/pubsub"
|
||||
"github.com/charmbracelet/crush/internal/session"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/anim"
|
||||
@@ -251,6 +252,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
p.sidebar = u.(sidebar.Sidebar)
|
||||
cmds = append(cmds, cmd)
|
||||
return p, tea.Batch(cmds...)
|
||||
case pubsub.Event[permission.PermissionNotification]:
|
||||
u, cmd := p.chat.Update(msg)
|
||||
p.chat = u.(chat.MessageListCmp)
|
||||
cmds = append(cmds, cmd)
|
||||
return p, tea.Batch(cmds...)
|
||||
|
||||
case commands.CommandRunCustomMsg:
|
||||
if p.app.CoderAgent.IsBusy() {
|
||||
|
||||
@@ -205,6 +205,11 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
Model: filepicker.NewFilePickerCmp(a.app.Config().WorkingDir()),
|
||||
})
|
||||
// Permissions
|
||||
case pubsub.Event[permission.PermissionNotification]:
|
||||
// forward to page
|
||||
updated, cmd := a.pages[a.currentPage].Update(msg)
|
||||
a.pages[a.currentPage] = updated.(util.Model)
|
||||
return a, cmd
|
||||
case pubsub.Event[permission.PermissionRequest]:
|
||||
return a, util.CmdHandler(dialogs.OpenDialogMsg{
|
||||
Model: permissions.NewPermissionDialogCmp(msg.Payload),
|
||||
|
||||
Reference in New Issue
Block a user