Files
crush-code-agent-ide/internal/llm/agent/mcp-tools.go
Carlos Alexandro Becker cd3ef8dbd4 fix: do not init MCP client on every tool request (#459)
* fix: do not init mcp client on every call

Right now each MCP client is initialized multiple times: once when discovering tools at startup, and then again every time any tool is called.

This makes it so we reuse the client from startup

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* wip

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* refactor: even better approach

* fix: unused param

* refactor: more improvements

* fix: if list tools fails, remove client

* fix: improve slice

* chore: smaller changes

---------

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
2025-08-01 15:22:47 -03:00

202 lines
5.1 KiB
Go

package agent
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"slices"
"sync"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/version"
"github.com/charmbracelet/crush/internal/permission"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
// Package-level MCP state shared by the functions in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced elsewhere in this file;
	// presumably a caller uses them to cache the discovered tool list once
	// per process — TODO confirm against the call site.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool

	// mcpClients holds the initialized MCP clients keyed by their configured
	// name. Entries are added by doInitClient, removed by getTools when
	// listing fails, and looked up by runTool on every tool call.
	mcpClients = csync.NewMap[string, *client.Client]()
)
// McpTool wraps a single tool exposed by an MCP server together with the
// dependencies needed to ask for permission before running it.
type McpTool struct {
	// mcpName is the configured name of the MCP server; it is also the key
	// used to look up the client in mcpClients.
	mcpName string
	// tool is the tool definition as returned by the server's tool listing.
	tool mcp.Tool
	// permissions is consulted before every execution.
	permissions permission.Service
	// workingDir is reported as the path of the permission request.
	workingDir string
}
// Name returns the tool identifier in the form "mcp_<server>_<tool>".
func (b *McpTool) Name() string {
	return "mcp_" + b.mcpName + "_" + b.tool.Name
}
// Info describes the tool: its prefixed name, the server-provided
// description, and the input schema's properties and required parameters.
func (b *McpTool) Info() tools.ToolInfo {
	info := tools.ToolInfo{
		Name:        "mcp_" + b.mcpName + "_" + b.tool.Name,
		Description: b.tool.Description,
		Parameters:  b.tool.InputSchema.Properties,
		Required:    b.tool.InputSchema.Required,
	}
	// Never hand out a nil Required slice; presumably so it serializes as []
	// rather than null — confirm against the consumers of ToolInfo.
	if info.Required == nil {
		info.Required = []string{}
	}
	return info
}
// runTool calls toolName on the already-initialized MCP client registered
// under name, passing input (a JSON object) as the tool arguments.
//
// The returned error is always nil: malformed input, a missing client and a
// failed call are all reported as text error responses.
//
// Every content item of the result is included in the response, separated by
// newlines. Text items contribute their text; any other item is rendered with
// its default %v formatting.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	c, ok := mcpClients.Get(name)
	if !ok {
		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
	}

	result, err := c.CallTool(ctx, mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	// Accumulate all content items instead of keeping only the last one, so
	// multi-part results are not silently truncated.
	var output []byte
	for i, item := range result.Content {
		if i > 0 {
			output = append(output, '\n')
		}
		if text, ok := item.(mcp.TextContent); ok {
			output = append(output, text.Text...)
			continue
		}
		output = fmt.Appendf(output, "%v", item)
	}
	return tools.NewTextResponse(string(output)), nil
}
// Run requests permission to execute the tool and, when granted, forwards the
// call's raw JSON input to the MCP server via runTool.
//
// It returns an error when the context carries no session ID or message ID,
// and permission.ErrorPermissionDenied when the permission request is
// rejected. Failures of the tool call itself are reported by runTool inside
// the returned ToolResponse.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing an MCP tool")
	}

	// Resolve the prefixed name once; it is used for both the description
	// and the permission request.
	toolName := b.Info().Name
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters: %s", toolName, params.Input),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
// getTools lists the tools offered by the MCP server behind c and wraps each
// of them in an McpTool bound to name, permissions and workingDir.
//
// When listing fails the error is logged, the client is closed and removed
// from mcpClients, and nil is returned.
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
	listed, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		slog.Error("error listing tools", "error", err)
		// Best-effort cleanup: the close error is intentionally ignored.
		_ = c.Close()
		mcpClients.Del(name)
		return nil
	}

	wrapped := make([]tools.BaseTool, len(listed.Tools))
	for i := range listed.Tools {
		wrapped[i] = &McpTool{
			mcpName:     name,
			tool:        listed.Tools[i],
			permissions: permissions,
			workingDir:  workingDir,
		}
	}
	return wrapped
}
// CloseMCPClients closes every client currently registered in mcpClients.
// Close errors are deliberately discarded. Intended to be called during
// application shutdown.
func CloseMCPClients() {
	for mcpClient := range mcpClients.Seq() {
		_ = mcpClient.Close()
	}
}
// doGetMCPTools connects to every enabled MCP server in cfg concurrently,
// initializes its client and gathers the tools each one exposes.
//
// A server that cannot be created or initialized is logged and skipped, so
// the result contains only the tools of the servers that succeeded. The
// function returns after all servers have been processed.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	collected := csync.NewSlice[tools.BaseTool]()

	// load handles a single server: create the client, initialize it and
	// append its tools to the shared result.
	load := func(name string, m config.MCPConfig) {
		c, err := doGetClient(m)
		if err != nil {
			slog.Error("error creating mcp client", "error", err)
			return
		}
		if err := doInitClient(ctx, name, c); err != nil {
			slog.Error("error initializing mcp client", "error", err)
			return
		}
		collected.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
	}

	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}
		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			defer wg.Done()
			load(name, m)
		}(name, m)
	}
	wg.Wait()

	return slices.Collect(collected.Seq())
}
// doInitClient performs the MCP initialize handshake on c, identifying this
// application as "Crush" with the current build version.
//
// On success the client is stored in mcpClients under name. On failure the
// client is closed and the initialization error is returned.
func doInitClient(ctx context.Context, name string, c *client.Client) error {
	var req mcp.InitializeRequest
	req.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
	req.Params.ClientInfo = mcp.Implementation{
		Name:    "Crush",
		Version: version.Version,
	}

	_, err := c.Initialize(ctx, req)
	if err != nil {
		// Best-effort cleanup: the close error is intentionally ignored.
		_ = c.Close()
		return err
	}

	mcpClients.Set(name, c)
	return nil
}
// doGetClient builds an MCP client for the transport selected by m.Type:
// stdio (command, resolved environment and args), streamable HTTP, or SSE
// (both using the URL and resolved headers). Any other type yields an error.
// The returned client has not been initialized yet.
func doGetClient(m config.MCPConfig) (*client.Client, error) {
	switch m.Type {
	case config.MCPStdio:
		env := m.ResolvedEnv()
		return client.NewStdioMCPClient(m.Command, env, m.Args...)
	case config.MCPHttp:
		withHeaders := transport.WithHTTPHeaders(m.ResolvedHeaders())
		return client.NewStreamableHttpClient(m.URL, withHeaders)
	case config.MCPSse:
		withHeaders := client.WithHeaders(m.ResolvedHeaders())
		return client.NewSSEMCPClient(m.URL, withHeaders)
	}
	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}