mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
fix: fix mcp clients
This commit is contained in:
@@ -256,6 +256,7 @@ func (app *App) InitCoderAgent() error {
|
||||
}
|
||||
var err error
|
||||
app.CoderAgent, err = agent.NewAgent(
|
||||
app.globalCtx,
|
||||
coderAgentCfg,
|
||||
app.Permissions,
|
||||
app.Sessions,
|
||||
|
||||
@@ -67,6 +67,7 @@ type agent struct {
|
||||
agentCfg config.Agent
|
||||
sessions session.Service
|
||||
messages message.Service
|
||||
mcpTools []McpTool
|
||||
|
||||
tools *csync.LazySlice[tools.BaseTool]
|
||||
|
||||
@@ -86,6 +87,7 @@ var agentPromptMap = map[string]prompt.PromptID{
|
||||
}
|
||||
|
||||
func NewAgent(
|
||||
ctx context.Context,
|
||||
agentCfg config.Agent,
|
||||
// These services are needed in the tools
|
||||
permissions permission.Service,
|
||||
@@ -94,7 +96,6 @@ func NewAgent(
|
||||
history history.Service,
|
||||
lspClients map[string]*lsp.Client,
|
||||
) (Service, error) {
|
||||
ctx := context.Background()
|
||||
cfg := config.Get()
|
||||
|
||||
var agentTool tools.BaseTool
|
||||
@@ -103,7 +104,7 @@ func NewAgent(
|
||||
if taskAgentCfg.ID == "" {
|
||||
return nil, fmt.Errorf("task agent not found in config")
|
||||
}
|
||||
taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
|
||||
taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create task agent: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,10 @@ import (
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
type mcpTool struct {
|
||||
type McpTool struct {
|
||||
mcpName string
|
||||
tool mcp.Tool
|
||||
client MCPClient
|
||||
mcpConfig config.MCPConfig
|
||||
permissions permission.Service
|
||||
workingDir string
|
||||
@@ -38,11 +39,11 @@ type MCPClient interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (b *mcpTool) Name() string {
|
||||
func (b *McpTool) Name() string {
|
||||
return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
|
||||
}
|
||||
|
||||
func (b *mcpTool) Info() tools.ToolInfo {
|
||||
func (b *McpTool) Info() tools.ToolInfo {
|
||||
required := b.tool.InputSchema.Required
|
||||
if required == nil {
|
||||
required = make([]string, 0)
|
||||
@@ -56,7 +57,6 @@ func (b *mcpTool) Info() tools.ToolInfo {
|
||||
}
|
||||
|
||||
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
|
||||
defer c.Close()
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
@@ -93,7 +93,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t
|
||||
return tools.NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
|
||||
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
|
||||
sessionID, messageID := tools.GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
|
||||
@@ -114,43 +114,13 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
|
||||
return tools.ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
switch b.mcpConfig.Type {
|
||||
case config.MCPStdio:
|
||||
c, err := client.NewStdioMCPClient(
|
||||
b.mcpConfig.Command,
|
||||
b.mcpConfig.ResolvedEnv(),
|
||||
b.mcpConfig.Args...,
|
||||
)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return runTool(ctx, c, b.tool.Name, params.Input)
|
||||
case config.MCPHttp:
|
||||
c, err := client.NewStreamableHttpClient(
|
||||
b.mcpConfig.URL,
|
||||
transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
|
||||
)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return runTool(ctx, c, b.tool.Name, params.Input)
|
||||
case config.MCPSse:
|
||||
c, err := client.NewSSEMCPClient(
|
||||
b.mcpConfig.URL,
|
||||
client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
|
||||
)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return runTool(ctx, c, b.tool.Name, params.Input)
|
||||
}
|
||||
|
||||
return tools.NewTextErrorResponse("invalid mcp type"), nil
|
||||
return runTool(ctx, b.client, b.tool.Name, params.Input)
|
||||
}
|
||||
|
||||
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
|
||||
return &mcpTool{
|
||||
func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
|
||||
return &McpTool{
|
||||
mcpName: name,
|
||||
client: c,
|
||||
tool: tool,
|
||||
mcpConfig: mcpConfig,
|
||||
permissions: permissions,
|
||||
@@ -179,9 +149,8 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions
|
||||
return stdioTools
|
||||
}
|
||||
for _, t := range tools.Tools {
|
||||
stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
|
||||
stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
|
||||
}
|
||||
defer c.Close()
|
||||
return stdioTools
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user