fix: do not init MCP client on every tool request (#459)

* fix: do not init mcp client on every call

Right now it inits each mcp client multiple times, one when discovering tools at startup, and then every time we call any tools.

This makes it so we reuse the client from startup

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* wip

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* refactor: even better approach

* fix: unused param

* refactor: more improvements

* fix: if list tools fails, remove client

* fix: improve slice

* chore: smaller changes

---------

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
This commit is contained in:
Carlos Alexandro Becker
2025-08-01 15:22:47 -03:00
committed by GitHub
parent 70f479e5a2
commit cd3ef8dbd4
3 changed files with 96 additions and 107 deletions

View File

@@ -268,6 +268,10 @@ func (app *App) InitCoderAgent() error {
slog.Error("Failed to create coder agent", "err", err)
return err
}
// Add MCP client cleanup to shutdown process
app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients)
setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
return nil
}

View File

@@ -189,7 +189,9 @@ func NewAgent(
tools.NewWriteTool(lspClients, permissions, history, cwd),
}
mcpTools := GetMCPTools(ctx, permissions, cfg)
mcpToolsOnce.Do(func() {
mcpTools = doGetMCPTools(ctx, permissions, cfg)
})
allTools = append(allTools, mcpTools...)
if len(lspClients) > 0 {

View File

@@ -11,34 +11,28 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/version"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/version"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
var (
mcpToolsOnce sync.Once
mcpTools []tools.BaseTool
mcpClients = csync.NewMap[string, *client.Client]()
)
type McpTool struct {
mcpName string
tool mcp.Tool
client MCPClient
mcpConfig config.MCPConfig
permissions permission.Service
workingDir string
}
type MCPClient interface {
Initialize(
ctx context.Context,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, error)
ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
Close() error
}
func (b *McpTool) Name() string {
return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
}
@@ -56,27 +50,21 @@ func (b *McpTool) Info() tools.ToolInfo {
}
}
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "Crush",
Version: version.Version,
}
_, err := c.Initialize(ctx, initRequest)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
toolRequest := mcp.CallToolRequest{}
toolRequest.Params.Name = toolName
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
var args map[string]any
if err = json.Unmarshal([]byte(input), &args); err != nil {
if err := json.Unmarshal([]byte(input), &args); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
toolRequest.Params.Arguments = args
result, err := c.CallTool(ctx, toolRequest)
c, ok := mcpClients.Get(name)
if !ok {
return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
}
result, err := c.CallTool(ctx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: toolName,
Arguments: args,
},
})
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
@@ -114,56 +102,34 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, permission.ErrorPermissionDenied
}
return runTool(ctx, b.client, b.tool.Name, params.Input)
return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
return &McpTool{
mcpName: name,
client: c,
tool: tool,
mcpConfig: mcpConfig,
permissions: permissions,
workingDir: workingDir,
}
}
func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
var stdioTools []tools.BaseTool
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "Crush",
Version: version.Version,
}
_, err := c.Initialize(ctx, initRequest)
if err != nil {
slog.Error("error initializing mcp client", "error", err)
return stdioTools
}
toolsRequest := mcp.ListToolsRequest{}
tools, err := c.ListTools(ctx, toolsRequest)
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
slog.Error("error listing tools", "error", err)
return stdioTools
c.Close()
mcpClients.Del(name)
return nil
}
for _, t := range tools.Tools {
stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
for _, tool := range result.Tools {
mcpTools = append(mcpTools, &McpTool{
mcpName: name,
tool: tool,
permissions: permissions,
workingDir: workingDir,
})
}
return stdioTools
return mcpTools
}
var (
mcpToolsOnce sync.Once
mcpTools []tools.BaseTool
)
func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
mcpToolsOnce.Do(func() {
mcpTools = doGetMCPTools(ctx, permissions, cfg)
})
return mcpTools
// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
func CloseMCPClients() {
for c := range mcpClients.Seq() {
_ = c.Close()
}
}
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
@@ -177,42 +143,59 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
wg.Add(1)
go func(name string, m config.MCPConfig) {
defer wg.Done()
switch m.Type {
case config.MCPStdio:
c, err := client.NewStdioMCPClient(
m.Command,
m.ResolvedEnv(),
m.Args...,
)
if err != nil {
slog.Error("error creating mcp client", "error", err)
return
}
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
case config.MCPHttp:
c, err := client.NewStreamableHttpClient(
m.URL,
transport.WithHTTPHeaders(m.ResolvedHeaders()),
)
if err != nil {
slog.Error("error creating mcp client", "error", err)
return
}
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
case config.MCPSse:
c, err := client.NewSSEMCPClient(
m.URL,
client.WithHeaders(m.ResolvedHeaders()),
)
if err != nil {
slog.Error("error creating mcp client", "error", err)
return
}
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
c, err := doGetClient(m)
if err != nil {
slog.Error("error creating mcp client", "error", err)
return
}
if err := doInitClient(ctx, name, c); err != nil {
slog.Error("error initializing mcp client", "error", err)
return
}
result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
}(name, m)
}
wg.Wait()
return slices.Collect(result.Seq())
}
func doInitClient(ctx context.Context, name string, c *client.Client) error {
initRequest := mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "Crush",
Version: version.Version,
},
},
}
if _, err := c.Initialize(ctx, initRequest); err != nil {
c.Close()
return err
}
mcpClients.Set(name, c)
return nil
}
func doGetClient(m config.MCPConfig) (*client.Client, error) {
switch m.Type {
case config.MCPStdio:
return client.NewStdioMCPClient(
m.Command,
m.ResolvedEnv(),
m.Args...,
)
case config.MCPHttp:
return client.NewStreamableHttpClient(
m.URL,
transport.WithHTTPHeaders(m.ResolvedHeaders()),
)
case config.MCPSse:
return client.NewSSEMCPClient(
m.URL,
client.WithHeaders(m.ResolvedHeaders()),
)
default:
return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
}