diff --git a/internal/app/app.go b/internal/app/app.go index ca48d3e4..849e4fcc 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -268,6 +268,10 @@ func (app *App) InitCoderAgent() error { slog.Error("Failed to create coder agent", "err", err) return err } + + // Add MCP client cleanup to shutdown process + app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients) + setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events) return nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index df2f0adf..80a7095e 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -189,7 +189,9 @@ func NewAgent( tools.NewWriteTool(lspClients, permissions, history, cwd), } - mcpTools := GetMCPTools(ctx, permissions, cfg) + mcpToolsOnce.Do(func() { + mcpTools = doGetMCPTools(ctx, permissions, cfg) + }) allTools = append(allTools, mcpTools...) if len(lspClients) > 0 { diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 77149d85..cc43a7fc 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -11,34 +11,28 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/version" "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/version" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) +var ( + mcpToolsOnce sync.Once + mcpTools []tools.BaseTool + mcpClients = csync.NewMap[string, *client.Client]() +) + type McpTool struct { mcpName string tool mcp.Tool - client MCPClient - mcpConfig config.MCPConfig permissions permission.Service workingDir string } -type MCPClient interface { - Initialize( - ctx context.Context, - request mcp.InitializeRequest, - ) (*mcp.InitializeResult, error) - ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error) - CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - Close() error -} - func (b *McpTool) Name() string { return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name) } @@ -56,27 +50,21 @@ func (b *McpTool) Info() tools.ToolInfo { } } -func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) { - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "Crush", - Version: version.Version, - } - - _, err := c.Initialize(ctx, initRequest) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - - toolRequest := mcp.CallToolRequest{} - toolRequest.Params.Name = toolName +func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) { var args map[string]any - if err = json.Unmarshal([]byte(input), &args); err != nil { + if err := json.Unmarshal([]byte(input), &args); err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } - toolRequest.Params.Arguments = args - result, err := c.CallTool(ctx, toolRequest) + c, ok := mcpClients.Get(name) + if !ok { + return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil + } + result, err := c.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: args, + }, + }) if err != nil { return tools.NewTextErrorResponse(err.Error()), nil } @@ -114,56 +102,34 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, permission.ErrorPermissionDenied } - return runTool(ctx, b.client, b.tool.Name, params.Input) + return runTool(ctx, b.mcpName, b.tool.Name, params.Input) } -func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool { - return &McpTool{ - mcpName: name, - client: c, - tool: tool, - mcpConfig: mcpConfig, - permissions: permissions, - workingDir: workingDir, - } -} - -func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool { - var stdioTools []tools.BaseTool - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "Crush", - Version: version.Version, - } - - _, err := c.Initialize(ctx, initRequest) - if err != nil { - slog.Error("error initializing mcp client", "error", err) - return stdioTools - } - toolsRequest := mcp.ListToolsRequest{} - tools, err := c.ListTools(ctx, toolsRequest) +func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool { + result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { slog.Error("error listing tools", "error", err) - return stdioTools + c.Close() + mcpClients.Del(name) + return nil } - for _, t := range tools.Tools { - stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir)) + mcpTools := make([]tools.BaseTool, 0, len(result.Tools)) + for _, tool := range result.Tools { + mcpTools = append(mcpTools, &McpTool{ + mcpName: name, + tool: tool, + permissions: permissions, + workingDir: workingDir, + }) } - return stdioTools + return mcpTools } -var ( - mcpToolsOnce sync.Once - mcpTools []tools.BaseTool -) - -func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { - mcpToolsOnce.Do(func() { - mcpTools = doGetMCPTools(ctx, permissions, cfg) - }) - return mcpTools +// CloseMCPClients closes all MCP clients. This should be called during application shutdown. +func CloseMCPClients() { + for c := range mcpClients.Seq() { + _ = c.Close() + } } func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { @@ -177,42 +143,59 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con wg.Add(1) go func(name string, m config.MCPConfig) { defer wg.Done() - switch m.Type { - case config.MCPStdio: - c, err := client.NewStdioMCPClient( - m.Command, - m.ResolvedEnv(), - m.Args..., - ) - if err != nil { - slog.Error("error creating mcp client", "error", err) - return - } - - result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - case config.MCPHttp: - c, err := client.NewStreamableHttpClient( - m.URL, - transport.WithHTTPHeaders(m.ResolvedHeaders()), - ) - if err != nil { - slog.Error("error creating mcp client", "error", err) - return - } - result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) - case config.MCPSse: - c, err := client.NewSSEMCPClient( - m.URL, - client.WithHeaders(m.ResolvedHeaders()), - ) - if err != nil { - slog.Error("error creating mcp client", "error", err) - return - } - result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...) + c, err := doGetClient(m) + if err != nil { + slog.Error("error creating mcp client", "error", err) + return } + if err := doInitClient(ctx, name, c); err != nil { + slog.Error("error initializing mcp client", "error", err) + return + } + result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...) }(name, m) } wg.Wait() return slices.Collect(result.Seq()) } + +func doInitClient(ctx context.Context, name string, c *client.Client) error { + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "Crush", + Version: version.Version, + }, + }, + } + if _, err := c.Initialize(ctx, initRequest); err != nil { + c.Close() + return err + } + mcpClients.Set(name, c) + return nil +} + +func doGetClient(m config.MCPConfig) (*client.Client, error) { + switch m.Type { + case config.MCPStdio: + return client.NewStdioMCPClient( + m.Command, + m.ResolvedEnv(), + m.Args..., + ) + case config.MCPHttp: + return client.NewStreamableHttpClient( + m.URL, + transport.WithHTTPHeaders(m.ResolvedHeaders()), + ) + case config.MCPSse: + return client.NewSSEMCPClient( + m.URL, + client.WithHeaders(m.ResolvedHeaders()), + ) + default: + return nil, fmt.Errorf("unsupported mcp type: %s", m.Type) + } +}