mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
fix: do not init MCP client on every tool request (#459)
* fix: do not init mcp client on every call Right now it inits each mcp client multiple times, one when discovering tools at startup, and then every time we call any tools. This makes it so we reuse the client from startup Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> * wip Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> * refactor: even better approach * fix: unused param * refactor: more improvements * fix: if list tools fails, remove client * fix: improve slice * chore: smaller changes --------- Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
70f479e5a2
commit
cd3ef8dbd4
@@ -268,6 +268,10 @@ func (app *App) InitCoderAgent() error {
|
|||||||
slog.Error("Failed to create coder agent", "err", err)
|
slog.Error("Failed to create coder agent", "err", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add MCP client cleanup to shutdown process
|
||||||
|
app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients)
|
||||||
|
|
||||||
setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
|
setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,7 +189,9 @@ func NewAgent(
|
|||||||
tools.NewWriteTool(lspClients, permissions, history, cwd),
|
tools.NewWriteTool(lspClients, permissions, history, cwd),
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpTools := GetMCPTools(ctx, permissions, cfg)
|
mcpToolsOnce.Do(func() {
|
||||||
|
mcpTools = doGetMCPTools(ctx, permissions, cfg)
|
||||||
|
})
|
||||||
allTools = append(allTools, mcpTools...)
|
allTools = append(allTools, mcpTools...)
|
||||||
|
|
||||||
if len(lspClients) > 0 {
|
if len(lspClients) > 0 {
|
||||||
|
|||||||
@@ -11,34 +11,28 @@ import (
|
|||||||
"github.com/charmbracelet/crush/internal/config"
|
"github.com/charmbracelet/crush/internal/config"
|
||||||
"github.com/charmbracelet/crush/internal/csync"
|
"github.com/charmbracelet/crush/internal/csync"
|
||||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||||
|
"github.com/charmbracelet/crush/internal/version"
|
||||||
|
|
||||||
"github.com/charmbracelet/crush/internal/permission"
|
"github.com/charmbracelet/crush/internal/permission"
|
||||||
"github.com/charmbracelet/crush/internal/version"
|
|
||||||
|
|
||||||
"github.com/mark3labs/mcp-go/client"
|
"github.com/mark3labs/mcp-go/client"
|
||||||
"github.com/mark3labs/mcp-go/client/transport"
|
"github.com/mark3labs/mcp-go/client/transport"
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mcpToolsOnce sync.Once
|
||||||
|
mcpTools []tools.BaseTool
|
||||||
|
mcpClients = csync.NewMap[string, *client.Client]()
|
||||||
|
)
|
||||||
|
|
||||||
type McpTool struct {
|
type McpTool struct {
|
||||||
mcpName string
|
mcpName string
|
||||||
tool mcp.Tool
|
tool mcp.Tool
|
||||||
client MCPClient
|
|
||||||
mcpConfig config.MCPConfig
|
|
||||||
permissions permission.Service
|
permissions permission.Service
|
||||||
workingDir string
|
workingDir string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MCPClient interface {
|
|
||||||
Initialize(
|
|
||||||
ctx context.Context,
|
|
||||||
request mcp.InitializeRequest,
|
|
||||||
) (*mcp.InitializeResult, error)
|
|
||||||
ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
|
|
||||||
CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *McpTool) Name() string {
|
func (b *McpTool) Name() string {
|
||||||
return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
|
return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
|
||||||
}
|
}
|
||||||
@@ -56,27 +50,21 @@ func (b *McpTool) Info() tools.ToolInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
|
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
|
||||||
initRequest := mcp.InitializeRequest{}
|
|
||||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
|
||||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
|
||||||
Name: "Crush",
|
|
||||||
Version: version.Version,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := c.Initialize(ctx, initRequest)
|
|
||||||
if err != nil {
|
|
||||||
return tools.NewTextErrorResponse(err.Error()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
toolRequest := mcp.CallToolRequest{}
|
|
||||||
toolRequest.Params.Name = toolName
|
|
||||||
var args map[string]any
|
var args map[string]any
|
||||||
if err = json.Unmarshal([]byte(input), &args); err != nil {
|
if err := json.Unmarshal([]byte(input), &args); err != nil {
|
||||||
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||||
}
|
}
|
||||||
toolRequest.Params.Arguments = args
|
c, ok := mcpClients.Get(name)
|
||||||
result, err := c.CallTool(ctx, toolRequest)
|
if !ok {
|
||||||
|
return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
|
||||||
|
}
|
||||||
|
result, err := c.CallTool(ctx, mcp.CallToolRequest{
|
||||||
|
Params: mcp.CallToolParams{
|
||||||
|
Name: toolName,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tools.NewTextErrorResponse(err.Error()), nil
|
return tools.NewTextErrorResponse(err.Error()), nil
|
||||||
}
|
}
|
||||||
@@ -114,56 +102,34 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
|||||||
return tools.ToolResponse{}, permission.ErrorPermissionDenied
|
return tools.ToolResponse{}, permission.ErrorPermissionDenied
|
||||||
}
|
}
|
||||||
|
|
||||||
return runTool(ctx, b.client, b.tool.Name, params.Input)
|
return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
|
func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
|
||||||
return &McpTool{
|
result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
|
||||||
mcpName: name,
|
|
||||||
client: c,
|
|
||||||
tool: tool,
|
|
||||||
mcpConfig: mcpConfig,
|
|
||||||
permissions: permissions,
|
|
||||||
workingDir: workingDir,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
|
|
||||||
var stdioTools []tools.BaseTool
|
|
||||||
initRequest := mcp.InitializeRequest{}
|
|
||||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
|
||||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
|
||||||
Name: "Crush",
|
|
||||||
Version: version.Version,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := c.Initialize(ctx, initRequest)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error initializing mcp client", "error", err)
|
|
||||||
return stdioTools
|
|
||||||
}
|
|
||||||
toolsRequest := mcp.ListToolsRequest{}
|
|
||||||
tools, err := c.ListTools(ctx, toolsRequest)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error listing tools", "error", err)
|
slog.Error("error listing tools", "error", err)
|
||||||
return stdioTools
|
c.Close()
|
||||||
|
mcpClients.Del(name)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
for _, t := range tools.Tools {
|
mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
|
||||||
stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
|
for _, tool := range result.Tools {
|
||||||
|
mcpTools = append(mcpTools, &McpTool{
|
||||||
|
mcpName: name,
|
||||||
|
tool: tool,
|
||||||
|
permissions: permissions,
|
||||||
|
workingDir: workingDir,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return stdioTools
|
return mcpTools
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
|
||||||
mcpToolsOnce sync.Once
|
func CloseMCPClients() {
|
||||||
mcpTools []tools.BaseTool
|
for c := range mcpClients.Seq() {
|
||||||
)
|
_ = c.Close()
|
||||||
|
}
|
||||||
func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
|
|
||||||
mcpToolsOnce.Do(func() {
|
|
||||||
mcpTools = doGetMCPTools(ctx, permissions, cfg)
|
|
||||||
})
|
|
||||||
return mcpTools
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
|
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
|
||||||
@@ -177,42 +143,59 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(name string, m config.MCPConfig) {
|
go func(name string, m config.MCPConfig) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
switch m.Type {
|
c, err := doGetClient(m)
|
||||||
case config.MCPStdio:
|
if err != nil {
|
||||||
c, err := client.NewStdioMCPClient(
|
slog.Error("error creating mcp client", "error", err)
|
||||||
m.Command,
|
return
|
||||||
m.ResolvedEnv(),
|
|
||||||
m.Args...,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error creating mcp client", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
|
|
||||||
case config.MCPHttp:
|
|
||||||
c, err := client.NewStreamableHttpClient(
|
|
||||||
m.URL,
|
|
||||||
transport.WithHTTPHeaders(m.ResolvedHeaders()),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error creating mcp client", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
|
|
||||||
case config.MCPSse:
|
|
||||||
c, err := client.NewSSEMCPClient(
|
|
||||||
m.URL,
|
|
||||||
client.WithHeaders(m.ResolvedHeaders()),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error creating mcp client", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
|
|
||||||
}
|
}
|
||||||
|
if err := doInitClient(ctx, name, c); err != nil {
|
||||||
|
slog.Error("error initializing mcp client", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
|
||||||
}(name, m)
|
}(name, m)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return slices.Collect(result.Seq())
|
return slices.Collect(result.Seq())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func doInitClient(ctx context.Context, name string, c *client.Client) error {
|
||||||
|
initRequest := mcp.InitializeRequest{
|
||||||
|
Params: mcp.InitializeParams{
|
||||||
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||||
|
ClientInfo: mcp.Implementation{
|
||||||
|
Name: "Crush",
|
||||||
|
Version: version.Version,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := c.Initialize(ctx, initRequest); err != nil {
|
||||||
|
c.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mcpClients.Set(name, c)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func doGetClient(m config.MCPConfig) (*client.Client, error) {
|
||||||
|
switch m.Type {
|
||||||
|
case config.MCPStdio:
|
||||||
|
return client.NewStdioMCPClient(
|
||||||
|
m.Command,
|
||||||
|
m.ResolvedEnv(),
|
||||||
|
m.Args...,
|
||||||
|
)
|
||||||
|
case config.MCPHttp:
|
||||||
|
return client.NewStreamableHttpClient(
|
||||||
|
m.URL,
|
||||||
|
transport.WithHTTPHeaders(m.ResolvedHeaders()),
|
||||||
|
)
|
||||||
|
case config.MCPSse:
|
||||||
|
return client.NewSSEMCPClient(
|
||||||
|
m.URL,
|
||||||
|
client.WithHeaders(m.ResolvedHeaders()),
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user