mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
feat: add a download tool
This commit is contained in:
@@ -102,6 +102,7 @@ func NewAgent(
|
||||
cwd := cfg.WorkingDir()
|
||||
allTools := []tools.BaseTool{
|
||||
tools.NewBashTool(permissions, cwd),
|
||||
tools.NewDownloadTool(permissions, cwd),
|
||||
tools.NewEditTool(lspClients, permissions, history, cwd),
|
||||
tools.NewFetchTool(permissions, cwd),
|
||||
tools.NewGlobTool(cwd),
|
||||
|
||||
223
internal/llm/tools/download.go
Normal file
223
internal/llm/tools/download.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
)
|
||||
|
||||
type DownloadParams struct {
|
||||
URL string `json:"url"`
|
||||
FilePath string `json:"file_path"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
type DownloadPermissionsParams struct {
|
||||
URL string `json:"url"`
|
||||
FilePath string `json:"file_path"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
type downloadTool struct {
|
||||
client *http.Client
|
||||
permissions permission.Service
|
||||
workingDir string
|
||||
}
|
||||
|
||||
const (
|
||||
DownloadToolName = "download"
|
||||
downloadToolDescription = `Downloads binary data from a URL and saves it to a local file.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to download files, images, or other binary data from URLs
|
||||
- Helpful for downloading assets, documents, or any file type
|
||||
- Useful for saving remote content locally for processing or storage
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the URL to download from
|
||||
- Specify the local file path where the content should be saved
|
||||
- Optionally set a timeout for the request
|
||||
|
||||
FEATURES:
|
||||
- Downloads any file type (binary or text)
|
||||
- Automatically creates parent directories if they don't exist
|
||||
- Handles large files efficiently with streaming
|
||||
- Sets reasonable timeouts to prevent hanging
|
||||
- Validates input parameters before making requests
|
||||
|
||||
LIMITATIONS:
|
||||
- Maximum file size is 100MB
|
||||
- Only supports HTTP and HTTPS protocols
|
||||
- Cannot handle authentication or cookies
|
||||
- Some websites may block automated requests
|
||||
- Will overwrite existing files without warning
|
||||
|
||||
TIPS:
|
||||
- Use absolute paths or paths relative to the working directory
|
||||
- Set appropriate timeouts for large files or slow connections`
|
||||
)
|
||||
|
||||
func NewDownloadTool(permissions permission.Service, workingDir string) BaseTool {
|
||||
return &downloadTool{
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Minute, // Default 5 minute timeout for downloads
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
},
|
||||
permissions: permissions,
|
||||
workingDir: workingDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *downloadTool) Name() string {
|
||||
return DownloadToolName
|
||||
}
|
||||
|
||||
func (t *downloadTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DownloadToolName,
|
||||
Description: downloadToolDescription,
|
||||
Parameters: map[string]any{
|
||||
"url": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The URL to download from",
|
||||
},
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The local file path where the downloaded content should be saved",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional timeout in seconds (max 600)",
|
||||
},
|
||||
},
|
||||
Required: []string{"url", "file_path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *downloadTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DownloadParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("Failed to parse download parameters: " + err.Error()), nil
|
||||
}
|
||||
|
||||
if params.URL == "" {
|
||||
return NewTextErrorResponse("URL parameter is required"), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return NewTextErrorResponse("file_path parameter is required"), nil
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
|
||||
return NewTextErrorResponse("URL must start with http:// or https://"), nil
|
||||
}
|
||||
|
||||
// Convert relative path to absolute path
|
||||
var filePath string
|
||||
if filepath.IsAbs(params.FilePath) {
|
||||
filePath = params.FilePath
|
||||
} else {
|
||||
filePath = filepath.Join(t.workingDir, params.FilePath)
|
||||
}
|
||||
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for downloading files")
|
||||
}
|
||||
|
||||
p := t.permissions.Request(
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: filePath,
|
||||
ToolName: DownloadToolName,
|
||||
Action: "download",
|
||||
Description: fmt.Sprintf("Download file from URL: %s to %s", params.URL, filePath),
|
||||
Params: DownloadPermissionsParams(params),
|
||||
},
|
||||
)
|
||||
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
// Handle timeout with context
|
||||
requestCtx := ctx
|
||||
if params.Timeout > 0 {
|
||||
maxTimeout := 600 // 10 minutes
|
||||
if params.Timeout > maxTimeout {
|
||||
params.Timeout = maxTimeout
|
||||
}
|
||||
var cancel context.CancelFunc
|
||||
requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "crush/1.0")
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to download from URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
|
||||
}
|
||||
|
||||
// Check content length if available
|
||||
maxSize := int64(100 * 1024 * 1024) // 100MB
|
||||
if resp.ContentLength > maxSize {
|
||||
return NewTextErrorResponse(fmt.Sprintf("File too large: %d bytes (max %d bytes)", resp.ContentLength, maxSize)), nil
|
||||
}
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
|
||||
}
|
||||
|
||||
// Create the output file
|
||||
outFile, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Copy data with size limit
|
||||
limitedReader := io.LimitReader(resp.Body, maxSize)
|
||||
bytesWritten, err := io.Copy(outFile, limitedReader)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Check if we hit the size limit
|
||||
if bytesWritten == maxSize {
|
||||
// Clean up the file since it might be incomplete
|
||||
os.Remove(filePath)
|
||||
return NewTextErrorResponse(fmt.Sprintf("File too large: exceeded %d bytes limit", maxSize)), nil
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
responseMsg := fmt.Sprintf("Successfully downloaded %d bytes to %s", bytesWritten, filePath)
|
||||
if contentType != "" {
|
||||
responseMsg += fmt.Sprintf(" (Content-Type: %s)", contentType)
|
||||
}
|
||||
|
||||
return NewTextResponse(responseMsg), nil
|
||||
}
|
||||
@@ -162,6 +162,7 @@ func (br baseRenderer) renderError(v *toolCallCmp, message string) string {
|
||||
// Register tool renderers
|
||||
func init() {
|
||||
registry.register(tools.BashToolName, func() renderer { return bashRenderer{} })
|
||||
registry.register(tools.DownloadToolName, func() renderer { return downloadRenderer{} })
|
||||
registry.register(tools.ViewToolName, func() renderer { return viewRenderer{} })
|
||||
registry.register(tools.EditToolName, func() renderer { return editRenderer{} })
|
||||
registry.register(tools.WriteToolName, func() renderer { return writeRenderer{} })
|
||||
@@ -376,6 +377,32 @@ func formatTimeout(timeout int) string {
|
||||
return (time.Duration(timeout) * time.Second).String()
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Download renderer
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// downloadRenderer handles file downloading with URL and file path display
|
||||
type downloadRenderer struct {
|
||||
baseRenderer
|
||||
}
|
||||
|
||||
// Render displays the download URL and destination file path with timeout parameter
|
||||
func (dr downloadRenderer) Render(v *toolCallCmp) string {
|
||||
var params tools.DownloadParams
|
||||
var args []string
|
||||
if err := dr.unmarshalParams(v.call.Input, ¶ms); err == nil {
|
||||
args = newParamBuilder().
|
||||
addMain(params.URL).
|
||||
addKeyValue("file_path", fsext.PrettyPath(params.FilePath)).
|
||||
addKeyValue("timeout", formatTimeout(params.Timeout)).
|
||||
build()
|
||||
}
|
||||
|
||||
return dr.renderWithParams(v, "Download", args, func() string {
|
||||
return renderPlainContent(v, v.result.Content)
|
||||
})
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Glob renderer
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -757,6 +784,8 @@ func prettifyToolName(name string) string {
|
||||
return "Agent"
|
||||
case tools.BashToolName:
|
||||
return "Bash"
|
||||
case tools.DownloadToolName:
|
||||
return "Download"
|
||||
case tools.EditToolName:
|
||||
return "Edit"
|
||||
case tools.FetchToolName:
|
||||
|
||||
@@ -252,6 +252,30 @@ func (p *permissionDialogCmp) renderHeader() string {
|
||||
switch p.permission.ToolName {
|
||||
case tools.BashToolName:
|
||||
headerParts = append(headerParts, t.S().Muted.Width(p.width).Render("Command"))
|
||||
case tools.DownloadToolName:
|
||||
params := p.permission.Params.(tools.DownloadPermissionsParams)
|
||||
urlKey := t.S().Muted.Render("URL")
|
||||
urlValue := t.S().Text.
|
||||
Width(p.width - lipgloss.Width(urlKey)).
|
||||
Render(fmt.Sprintf(" %s", params.URL))
|
||||
fileKey := t.S().Muted.Render("File")
|
||||
filePath := t.S().Text.
|
||||
Width(p.width - lipgloss.Width(fileKey)).
|
||||
Render(fmt.Sprintf(" %s", fsext.PrettyPath(params.FilePath)))
|
||||
headerParts = append(headerParts,
|
||||
lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
urlKey,
|
||||
urlValue,
|
||||
),
|
||||
baseStyle.Render(strings.Repeat(" ", p.width)),
|
||||
lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
fileKey,
|
||||
filePath,
|
||||
),
|
||||
baseStyle.Render(strings.Repeat(" ", p.width)),
|
||||
)
|
||||
case tools.EditToolName:
|
||||
params := p.permission.Params.(tools.EditPermissionsParams)
|
||||
fileKey := t.S().Muted.Render("File")
|
||||
@@ -299,6 +323,8 @@ func (p *permissionDialogCmp) getOrGenerateContent() string {
|
||||
switch p.permission.ToolName {
|
||||
case tools.BashToolName:
|
||||
content = p.generateBashContent()
|
||||
case tools.DownloadToolName:
|
||||
content = p.generateDownloadContent()
|
||||
case tools.EditToolName:
|
||||
content = p.generateEditContent()
|
||||
case tools.WriteToolName:
|
||||
@@ -391,6 +417,24 @@ func (p *permissionDialogCmp) generateWriteContent() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) generateDownloadContent() string {
|
||||
t := styles.CurrentTheme()
|
||||
baseStyle := t.S().Base.Background(t.BgSubtle)
|
||||
if pr, ok := p.permission.Params.(tools.DownloadPermissionsParams); ok {
|
||||
content := fmt.Sprintf("URL: %s\nFile: %s", pr.URL, fsext.PrettyPath(pr.FilePath))
|
||||
if pr.Timeout > 0 {
|
||||
content += fmt.Sprintf("\nTimeout: %ds", pr.Timeout)
|
||||
}
|
||||
|
||||
finalContent := baseStyle.
|
||||
Padding(1, 2).
|
||||
Width(p.contentViewPort.Width()).
|
||||
Render(content)
|
||||
return finalContent
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) generateFetchContent() string {
|
||||
t := styles.CurrentTheme()
|
||||
baseStyle := t.S().Base.Background(t.BgSubtle)
|
||||
@@ -526,6 +570,9 @@ func (p *permissionDialogCmp) SetSize() tea.Cmd {
|
||||
case tools.BashToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.3)
|
||||
case tools.DownloadToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.4)
|
||||
case tools.EditToolName:
|
||||
p.width = int(float64(p.wWidth) * 0.8)
|
||||
p.height = int(float64(p.wHeight) * 0.8)
|
||||
|
||||
Reference in New Issue
Block a user