mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
chore(deps): update mcp-go (#155)
* chore(deps): update mcp-go Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> * fix: vendoring Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> --------- Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
531b3fd44f
commit
3e820ececc
2
go.mod
2
go.mod
@@ -29,7 +29,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/mark3labs/mcp-go v0.32.0
|
||||
github.com/mark3labs/mcp-go v0.33.0
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/ncruces/go-sqlite3 v0.25.0
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
|
||||
4
go.sum
4
go.sum
@@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
|
||||
github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
|
||||
github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc=
|
||||
github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
|
||||
102
vendor/github.com/mark3labs/mcp-go/client/client.go
generated
vendored
102
vendor/github.com/mark3labs/mcp-go/client/client.go
generated
vendored
@@ -22,6 +22,7 @@ type Client struct {
|
||||
requestID atomic.Int64
|
||||
clientCapabilities mcp.ClientCapabilities
|
||||
serverCapabilities mcp.ServerCapabilities
|
||||
samplingHandler SamplingHandler
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSamplingHandler sets the sampling handler for the client.
|
||||
// When set, the client will declare sampling capability during initialization.
|
||||
func WithSamplingHandler(handler SamplingHandler) ClientOption {
|
||||
return func(c *Client) {
|
||||
c.samplingHandler = handler
|
||||
}
|
||||
}
|
||||
|
||||
// WithSession assumes a MCP Session has already been initialized
|
||||
func WithSession() ClientOption {
|
||||
return func(c *Client) {
|
||||
c.initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new MCP client with the given transport.
|
||||
// Usage:
|
||||
//
|
||||
@@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
|
||||
handler(notification)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up request handler for bidirectional communication (e.g., sampling)
|
||||
if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
|
||||
bidirectional.SetRequestHandler(c.handleIncomingRequest)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -127,6 +149,12 @@ func (c *Client) Initialize(
|
||||
ctx context.Context,
|
||||
request mcp.InitializeRequest,
|
||||
) (*mcp.InitializeResult, error) {
|
||||
// Merge client capabilities with sampling capability if handler is configured
|
||||
capabilities := request.Params.Capabilities
|
||||
if c.samplingHandler != nil {
|
||||
capabilities.Sampling = &struct{}{}
|
||||
}
|
||||
|
||||
// Ensure we send a params object with all required fields
|
||||
params := struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
@@ -135,7 +163,7 @@ func (c *Client) Initialize(
|
||||
}{
|
||||
ProtocolVersion: request.Params.ProtocolVersion,
|
||||
ClientInfo: request.Params.ClientInfo,
|
||||
Capabilities: request.Params.Capabilities, // Will be empty struct if not set
|
||||
Capabilities: capabilities,
|
||||
}
|
||||
|
||||
response, err := c.sendRequest(ctx, "initialize", params)
|
||||
@@ -398,6 +426,64 @@ func (c *Client) Complete(
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// handleIncomingRequest processes incoming requests from the server.
|
||||
// This is the main entry point for server-to-client requests like sampling.
|
||||
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
|
||||
switch request.Method {
|
||||
case string(mcp.MethodSamplingCreateMessage):
|
||||
return c.handleSamplingRequestTransport(ctx, request)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSamplingRequestTransport handles sampling requests at the transport level.
|
||||
func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
|
||||
if c.samplingHandler == nil {
|
||||
return nil, fmt.Errorf("no sampling handler configured")
|
||||
}
|
||||
|
||||
// Parse the request parameters
|
||||
var params mcp.CreateMessageParams
|
||||
if request.Params != nil {
|
||||
paramsBytes, err := json.Marshal(request.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal params: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the MCP request
|
||||
mcpRequest := mcp.CreateMessageRequest{
|
||||
Request: mcp.Request{
|
||||
Method: string(mcp.MethodSamplingCreateMessage),
|
||||
},
|
||||
CreateMessageParams: params,
|
||||
}
|
||||
|
||||
// Call the sampling handler
|
||||
result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Marshal the result
|
||||
resultBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal result: %w", err)
|
||||
}
|
||||
|
||||
// Create the transport response
|
||||
response := &transport.JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: request.ID,
|
||||
Result: json.RawMessage(resultBytes),
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
func listByPage[T any](
|
||||
ctx context.Context,
|
||||
client *Client,
|
||||
@@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
|
||||
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
|
||||
return c.clientCapabilities
|
||||
}
|
||||
|
||||
// GetSessionId returns the session ID of the transport.
|
||||
// If the transport does not support sessions, it returns an empty string.
|
||||
func (c *Client) GetSessionId() string {
|
||||
if c.transport == nil {
|
||||
return ""
|
||||
}
|
||||
return c.transport.GetSessionId()
|
||||
}
|
||||
|
||||
// IsInitialized returns true if the client has been initialized.
|
||||
func (c *Client) IsInitialized() bool {
|
||||
return c.initialized
|
||||
}
|
||||
|
||||
7
vendor/github.com/mark3labs/mcp-go/client/http.go
generated
vendored
7
vendor/github.com/mark3labs/mcp-go/client/http.go
generated
vendored
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
|
||||
}
|
||||
return NewClient(trans), nil
|
||||
clientOptions := make([]ClientOption, 0)
|
||||
sessionID := trans.GetSessionId()
|
||||
if sessionID != "" {
|
||||
clientOptions = append(clientOptions, WithSession())
|
||||
}
|
||||
return NewClient(trans, clientOptions...), nil
|
||||
}
|
||||
|
||||
20
vendor/github.com/mark3labs/mcp-go/client/sampling.go
generated
vendored
Normal file
20
vendor/github.com/mark3labs/mcp-go/client/sampling.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// SamplingHandler defines the interface for handling sampling requests from servers.
|
||||
// Clients can implement this interface to provide LLM sampling capabilities to servers.
|
||||
type SamplingHandler interface {
|
||||
// CreateMessage handles a sampling request from the server and returns the generated message.
|
||||
// The implementation should:
|
||||
// 1. Validate the request parameters
|
||||
// 2. Optionally prompt the user for approval (human-in-the-loop)
|
||||
// 3. Select an appropriate model based on preferences
|
||||
// 4. Generate the response using the selected model
|
||||
// 5. Return the result with model information and stop reason
|
||||
CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
|
||||
}
|
||||
22
vendor/github.com/mark3labs/mcp-go/client/stdio.go
generated
vendored
22
vendor/github.com/mark3labs/mcp-go/client/stdio.go
generated
vendored
@@ -19,10 +19,26 @@ func NewStdioMCPClient(
|
||||
env []string,
|
||||
args ...string,
|
||||
) (*Client, error) {
|
||||
return NewStdioMCPClientWithOptions(command, env, args)
|
||||
}
|
||||
|
||||
stdioTransport := transport.NewStdio(command, env, args...)
|
||||
err := stdioTransport.Start(context.Background())
|
||||
if err != nil {
|
||||
// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess.
|
||||
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
|
||||
// Optional configuration functions can be provided to customize the transport before it starts,
|
||||
// such as setting a custom command function.
|
||||
//
|
||||
// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport.
|
||||
// Don't call the Start method manually.
|
||||
// This is for backward compatibility.
|
||||
func NewStdioMCPClientWithOptions(
|
||||
command string,
|
||||
env []string,
|
||||
args []string,
|
||||
opts ...transport.StdioOption,
|
||||
) (*Client, error) {
|
||||
stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...)
|
||||
|
||||
if err := stdioTransport.Start(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to start stdio transport: %w", err)
|
||||
}
|
||||
|
||||
|
||||
4
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go
generated
vendored
4
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go
generated
vendored
@@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
|
||||
func (*InProcessTransport) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *InProcessTransport) GetSessionId() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
20
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go
generated
vendored
20
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go
generated
vendored
@@ -29,6 +29,22 @@ type Interface interface {
|
||||
|
||||
// Close the connection.
|
||||
Close() error
|
||||
|
||||
// GetSessionId returns the session ID of the transport.
|
||||
GetSessionId() string
|
||||
}
|
||||
|
||||
// RequestHandler defines a function that handles incoming requests from the server.
|
||||
type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
|
||||
|
||||
// BidirectionalInterface extends Interface to support incoming requests from the server.
|
||||
// This is used for features like sampling where the server can send requests to the client.
|
||||
type BidirectionalInterface interface {
|
||||
Interface
|
||||
|
||||
// SetRequestHandler sets the handler for incoming requests from the server.
|
||||
// The handler should process the request and return a response.
|
||||
SetRequestHandler(handler RequestHandler)
|
||||
}
|
||||
|
||||
type JSONRPCRequest struct {
|
||||
@@ -41,10 +57,10 @@ type JSONRPCRequest struct {
|
||||
type JSONRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID mcp.RequestId `json:"id"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
} `json:"error"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
6
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go
generated
vendored
6
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go
generated
vendored
@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSessionId returns the session ID of the transport.
|
||||
// Since SSE does not maintain a session ID, it returns an empty string.
|
||||
func (c *SSE) GetSessionId() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
|
||||
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
|
||||
if c.endpoint == nil {
|
||||
|
||||
204
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go
generated
vendored
204
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go
generated
vendored
@@ -23,6 +23,7 @@ type Stdio struct {
|
||||
env []string
|
||||
|
||||
cmd *exec.Cmd
|
||||
cmdFunc CommandFunc
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr io.ReadCloser
|
||||
@@ -31,6 +32,28 @@ type Stdio struct {
|
||||
done chan struct{}
|
||||
onNotification func(mcp.JSONRPCNotification)
|
||||
notifyMu sync.RWMutex
|
||||
onRequest RequestHandler
|
||||
requestMu sync.RWMutex
|
||||
ctx context.Context
|
||||
ctxMu sync.RWMutex
|
||||
}
|
||||
|
||||
// StdioOption defines a function that configures a Stdio transport instance.
|
||||
// Options can be used to customize the behavior of the transport before it starts,
|
||||
// such as setting a custom command function.
|
||||
type StdioOption func(*Stdio)
|
||||
|
||||
// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
|
||||
// It can be used to apply sandboxing, custom environment control, working directories, etc.
|
||||
type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
|
||||
|
||||
// WithCommandFunc sets a custom command factory function for the stdio transport.
|
||||
// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
|
||||
// allowing control over attributes like environment, working directory, and system-level sandboxing.
|
||||
func WithCommandFunc(f CommandFunc) StdioOption {
|
||||
return func(s *Stdio) {
|
||||
s.cmdFunc = f
|
||||
}
|
||||
}
|
||||
|
||||
// NewIO returns a new stdio-based transport using existing input, output, and
|
||||
@@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
|
||||
|
||||
responses: make(map[string]chan *JSONRPCResponse),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,20 +79,43 @@ func NewStdio(
|
||||
env []string,
|
||||
args ...string,
|
||||
) *Stdio {
|
||||
return NewStdioWithOptions(command, env, args)
|
||||
}
|
||||
|
||||
client := &Stdio{
|
||||
// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
|
||||
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
|
||||
// Returns an error if the subprocess cannot be started or the pipes cannot be created.
|
||||
// Optional configuration functions can be provided to customize the transport before it starts,
|
||||
// such as setting a custom command factory.
|
||||
func NewStdioWithOptions(
|
||||
command string,
|
||||
env []string,
|
||||
args []string,
|
||||
opts ...StdioOption,
|
||||
) *Stdio {
|
||||
s := &Stdio{
|
||||
command: command,
|
||||
args: args,
|
||||
env: env,
|
||||
|
||||
responses: make(map[string]chan *JSONRPCResponse),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
return client
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *Stdio) Start(ctx context.Context) error {
|
||||
// Store the context for use in request handling
|
||||
c.ctxMu.Lock()
|
||||
c.ctx = ctx
|
||||
c.ctxMu.Unlock()
|
||||
|
||||
if err := c.spawnCommand(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// spawnCommand spawns a new process running c.command.
|
||||
// spawnCommand spawns a new process running the configured command, args, and env.
|
||||
// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
|
||||
// otherwise, the default behavior uses exec.CommandContext with the merged environment.
|
||||
// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
|
||||
func (c *Stdio) spawnCommand(ctx context.Context) error {
|
||||
if c.command == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.command, c.args...)
|
||||
var cmd *exec.Cmd
|
||||
var err error
|
||||
|
||||
mergedEnv := os.Environ()
|
||||
mergedEnv = append(mergedEnv, c.env...)
|
||||
|
||||
cmd.Env = mergedEnv
|
||||
// Standard behavior if no command func present.
|
||||
if c.cmdFunc == nil {
|
||||
cmd = exec.CommandContext(ctx, c.command, c.args...)
|
||||
cmd.Env = append(os.Environ(), c.env...)
|
||||
} else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
@@ -148,6 +202,12 @@ func (c *Stdio) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSessionId returns the session ID of the transport.
|
||||
// Since stdio does not maintain a session ID, it returns an empty string.
|
||||
func (c *Stdio) GetSessionId() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetNotificationHandler sets the handler function to be called when a notification is received.
|
||||
// Only one handler can be set at a time; setting a new one replaces the previous handler.
|
||||
func (c *Stdio) SetNotificationHandler(
|
||||
@@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler(
|
||||
c.onNotification = handler
|
||||
}
|
||||
|
||||
// SetRequestHandler sets the handler function to be called when a request is received from the server.
|
||||
// This enables bidirectional communication for features like sampling.
|
||||
func (c *Stdio) SetRequestHandler(handler RequestHandler) {
|
||||
c.requestMu.Lock()
|
||||
defer c.requestMu.Unlock()
|
||||
c.onRequest = handler
|
||||
}
|
||||
|
||||
// readResponses continuously reads and processes responses from the server's stdout.
|
||||
// It handles both responses to requests and notifications, routing them appropriately.
|
||||
// Runs until the done channel is closed or an error occurs reading from stdout.
|
||||
@@ -175,13 +243,18 @@ func (c *Stdio) readResponses() {
|
||||
return
|
||||
}
|
||||
|
||||
var baseMessage JSONRPCResponse
|
||||
// First try to parse as a generic message to check for ID field
|
||||
var baseMessage struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *mcp.RequestId `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle notification
|
||||
if baseMessage.ID.IsNil() {
|
||||
// If it has a method but no ID, it's a notification
|
||||
if baseMessage.Method != "" && baseMessage.ID == nil {
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
|
||||
continue
|
||||
@@ -194,15 +267,30 @@ func (c *Stdio) readResponses() {
|
||||
continue
|
||||
}
|
||||
|
||||
// If it has a method and an ID, it's an incoming request
|
||||
if baseMessage.Method != "" && baseMessage.ID != nil {
|
||||
var request JSONRPCRequest
|
||||
if err := json.Unmarshal([]byte(line), &request); err == nil {
|
||||
c.handleIncomingRequest(request)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, it's a response to our request
|
||||
var response JSONRPCResponse
|
||||
if err := json.Unmarshal([]byte(line), &response); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create string key for map lookup
|
||||
idKey := baseMessage.ID.String()
|
||||
idKey := response.ID.String()
|
||||
|
||||
c.mu.RLock()
|
||||
ch, exists := c.responses[idKey]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
ch <- &baseMessage
|
||||
ch <- &response
|
||||
c.mu.Lock()
|
||||
delete(c.responses, idKey)
|
||||
c.mu.Unlock()
|
||||
@@ -281,6 +369,96 @@ func (c *Stdio) SendNotification(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleIncomingRequest processes incoming requests from the server.
|
||||
// It calls the registered request handler and sends the response back to the server.
|
||||
func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
|
||||
c.requestMu.RLock()
|
||||
handler := c.onRequest
|
||||
c.requestMu.RUnlock()
|
||||
|
||||
if handler == nil {
|
||||
// Send error response if no handler is configured
|
||||
errorResponse := JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: request.ID,
|
||||
Error: &struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{
|
||||
Code: mcp.METHOD_NOT_FOUND,
|
||||
Message: "No request handler configured",
|
||||
},
|
||||
}
|
||||
c.sendResponse(errorResponse)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle the request in a goroutine to avoid blocking
|
||||
go func() {
|
||||
c.ctxMu.RLock()
|
||||
ctx := c.ctx
|
||||
c.ctxMu.RUnlock()
|
||||
|
||||
// Check if context is already cancelled before processing
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errorResponse := JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: request.ID,
|
||||
Error: &struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{
|
||||
Code: mcp.INTERNAL_ERROR,
|
||||
Message: ctx.Err().Error(),
|
||||
},
|
||||
}
|
||||
c.sendResponse(errorResponse)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
response, err := handler(ctx, request)
|
||||
|
||||
if err != nil {
|
||||
errorResponse := JSONRPCResponse{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: request.ID,
|
||||
Error: &struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{
|
||||
Code: mcp.INTERNAL_ERROR,
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
c.sendResponse(errorResponse)
|
||||
return
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
c.sendResponse(*response)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// sendResponse sends a response back to the server.
|
||||
func (c *Stdio) sendResponse(response JSONRPCResponse) {
|
||||
responseBytes, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
fmt.Printf("Error marshaling response: %v\n", err)
|
||||
return
|
||||
}
|
||||
responseBytes = append(responseBytes, '\n')
|
||||
|
||||
if _, err := c.stdin.Write(responseBytes); err != nil {
|
||||
fmt.Printf("Error writing response: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stderr returns a reader for the stderr output of the subprocess.
|
||||
// This can be used to capture error messages or logs from the subprocess.
|
||||
func (c *Stdio) Stderr() io.Reader {
|
||||
|
||||
347
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go
generated
vendored
347
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go
generated
vendored
@@ -17,10 +17,24 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/util"
|
||||
)
|
||||
|
||||
type StreamableHTTPCOption func(*StreamableHTTP)
|
||||
|
||||
// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
|
||||
// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
|
||||
// you should enable this option.
|
||||
//
|
||||
// It will establish a standalone long-live GET HTTP connection to the server.
|
||||
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
|
||||
// NOTICE: Even enabled, the server may not support this feature.
|
||||
func WithContinuousListening() StreamableHTTPCOption {
|
||||
return func(sc *StreamableHTTP) {
|
||||
sc.getListeningEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
|
||||
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
|
||||
return func(sc *StreamableHTTP) {
|
||||
@@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger util.Logger) StreamableHTTPCOption {
|
||||
return func(sc *StreamableHTTP) {
|
||||
sc.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithSession creates a client with a pre-configured session
|
||||
func WithSession(sessionID string) StreamableHTTPCOption {
|
||||
return func(sc *StreamableHTTP) {
|
||||
sc.sessionID.Store(sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// StreamableHTTP implements Streamable HTTP transport.
|
||||
//
|
||||
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
|
||||
@@ -64,8 +91,6 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
|
||||
//
|
||||
// The current implementation does not support the following features:
|
||||
// - batching
|
||||
// - continuously listening for server notifications when no request is in flight
|
||||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
|
||||
// - resuming stream
|
||||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
|
||||
// - server -> client request
|
||||
@@ -74,9 +99,14 @@ type StreamableHTTP struct {
|
||||
httpClient *http.Client
|
||||
headers map[string]string
|
||||
headerFunc HTTPHeaderFunc
|
||||
logger util.Logger
|
||||
getListeningEnabled bool
|
||||
|
||||
sessionID atomic.Value // string
|
||||
|
||||
initialized chan struct{}
|
||||
initializedOnce sync.Once
|
||||
|
||||
notificationHandler func(mcp.JSONRPCNotification)
|
||||
notifyMu sync.RWMutex
|
||||
|
||||
@@ -99,12 +129,16 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
|
||||
httpClient: &http.Client{},
|
||||
headers: make(map[string]string),
|
||||
closed: make(chan struct{}),
|
||||
logger: util.DefaultLogger(),
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
smc.sessionID.Store("") // set initial value to simplify later usage
|
||||
|
||||
for _, opt := range options {
|
||||
if opt != nil {
|
||||
opt(smc)
|
||||
}
|
||||
}
|
||||
|
||||
// If OAuth is configured, set the base URL for metadata discovery
|
||||
if smc.oauthHandler != nil {
|
||||
@@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
|
||||
|
||||
// Start initiates the HTTP connection to the server.
|
||||
func (c *StreamableHTTP) Start(ctx context.Context) error {
|
||||
// For Streamable HTTP, we don't need to establish a persistent connection
|
||||
// For Streamable HTTP, we don't need to establish a persistent connection by default
|
||||
if c.getListeningEnabled {
|
||||
go func() {
|
||||
select {
|
||||
case <-c.initialized:
|
||||
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||
defer cancel()
|
||||
c.listenForever(ctx)
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error {
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to create close request\n: %v", err)
|
||||
c.logger.Errorf("failed to create close request: %v", err)
|
||||
return
|
||||
}
|
||||
req.Header.Set(headerKeySessionID, sessionId)
|
||||
res, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to send close request\n: %v", err)
|
||||
c.logger.Errorf("failed to send close request: %v", err)
|
||||
return
|
||||
}
|
||||
res.Body.Close()
|
||||
@@ -185,37 +232,103 @@ func (c *StreamableHTTP) SendRequest(
|
||||
request JSONRPCRequest,
|
||||
) (*JSONRPCResponse, error) {
|
||||
|
||||
// Create a combined context that could be canceled when the client is closed
|
||||
newCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
// The original context was canceled, no need to do anything
|
||||
}
|
||||
}()
|
||||
ctx = newCtx
|
||||
|
||||
// Marshal request
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
|
||||
// If the request is initialize, should not return a SessionTerminated error
|
||||
// It should be a genuine endpoint-routing issue.
|
||||
// ( Fall through to return StatusCode checking. )
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if we got an error response
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||
|
||||
// Handle OAuth unauthorized error
|
||||
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
|
||||
return nil, &OAuthAuthorizationRequiredError{
|
||||
Handler: c.oauthHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// handle error response
|
||||
var errResponse JSONRPCResponse
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := json.Unmarshal(body, &errResponse); err == nil {
|
||||
return &errResponse, nil
|
||||
}
|
||||
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
if request.Method == string(mcp.MethodInitialize) {
|
||||
// saved the received session ID in the response
|
||||
// empty session ID is allowed
|
||||
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
|
||||
c.sessionID.Store(sessionID)
|
||||
}
|
||||
|
||||
c.initializedOnce.Do(func() {
|
||||
close(c.initialized)
|
||||
})
|
||||
}
|
||||
|
||||
// Handle different response types
|
||||
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
switch mediaType {
|
||||
case "application/json":
|
||||
// Single response
|
||||
var response JSONRPCResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// should not be a notification
|
||||
if response.ID.IsNil() {
|
||||
return nil, fmt.Errorf("response should contain RPC id: %v", response)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
case "text/event-stream":
|
||||
// Server is using SSE for streaming responses
|
||||
return c.handleSSEResponse(ctx, resp.Body, false)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StreamableHTTP) sendHTTP(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
body io.Reader,
|
||||
acceptType string,
|
||||
) (resp *http.Response, err error) {
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json, text/event-stream")
|
||||
sessionID := c.sessionID.Load()
|
||||
req.Header.Set("Accept", acceptType)
|
||||
sessionID := c.sessionID.Load().(string)
|
||||
if sessionID != "" {
|
||||
req.Header.Set(headerKeySessionID, sessionID.(string))
|
||||
req.Header.Set(headerKeySessionID, sessionID)
|
||||
}
|
||||
for k, v := range c.headers {
|
||||
req.Header.Set(k, v)
|
||||
@@ -243,73 +356,24 @@ func (c *StreamableHTTP) SendRequest(
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
resp, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if we got an error response
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||
// handle session closed
|
||||
// universal handling for session terminated
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
c.sessionID.CompareAndSwap(sessionID, "")
|
||||
return nil, fmt.Errorf("session terminated (404). need to re-initialize")
|
||||
return nil, ErrSessionTerminated
|
||||
}
|
||||
|
||||
// Handle OAuth unauthorized error
|
||||
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
|
||||
return nil, &OAuthAuthorizationRequiredError{
|
||||
Handler: c.oauthHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// handle error response
|
||||
var errResponse JSONRPCResponse
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := json.Unmarshal(body, &errResponse); err == nil {
|
||||
return &errResponse, nil
|
||||
}
|
||||
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
if request.Method == string(mcp.MethodInitialize) {
|
||||
// saved the received session ID in the response
|
||||
// empty session ID is allowed
|
||||
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
|
||||
c.sessionID.Store(sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle different response types
|
||||
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
switch mediaType {
|
||||
case "application/json":
|
||||
// Single response
|
||||
var response JSONRPCResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// should not be a notification
|
||||
if response.ID.IsNil() {
|
||||
return nil, fmt.Errorf("response should contain RPC id: %v", response)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
case "text/event-stream":
|
||||
// Server is using SSE for streaming responses
|
||||
return c.handleSSEResponse(ctx, resp.Body)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// handleSSEResponse processes an SSE stream for a specific request.
|
||||
// It returns the final result for the request once received, or an error.
|
||||
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
|
||||
// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
|
||||
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
|
||||
|
||||
// Create a channel for this specific request
|
||||
responseChan := make(chan *JSONRPCResponse, 1)
|
||||
@@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
||||
|
||||
var message JSONRPCResponse
|
||||
if err := json.Unmarshal([]byte(data), &message); err != nil {
|
||||
fmt.Printf("failed to unmarshal message: %v\n", err)
|
||||
c.logger.Errorf("failed to unmarshal message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
||||
if message.ID.IsNil() {
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
|
||||
fmt.Printf("failed to unmarshal notification: %v\n", err)
|
||||
c.logger.Errorf("failed to unmarshal notification: %v", err)
|
||||
return
|
||||
}
|
||||
c.notifyMu.RLock()
|
||||
@@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
||||
return
|
||||
}
|
||||
|
||||
if !ignoreResponse {
|
||||
responseChan <- &message
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
@@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
fmt.Printf("SSE stream error: %v\n", err)
|
||||
c.logger.Errorf("SSE stream error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json, text/event-stream")
|
||||
if sessionID := c.sessionID.Load(); sessionID != "" {
|
||||
req.Header.Set(headerKeySessionID, sessionID.(string))
|
||||
}
|
||||
for k, v := range c.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Add OAuth authorization if configured
|
||||
if c.oauthHandler != nil {
|
||||
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
|
||||
if err != nil {
|
||||
// If we get an authorization error, return a specific error that can be handled by the client
|
||||
if errors.Is(err, ErrOAuthAuthorizationRequired) {
|
||||
return &OAuthAuthorizationRequiredError{
|
||||
Handler: c.oauthHandler,
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to get authorization header: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
}
|
||||
|
||||
if c.headerFunc != nil {
|
||||
for k, v := range c.headerFunc(ctx) {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
|
||||
func (c *StreamableHTTP) IsOAuthEnabled() bool {
|
||||
return c.oauthHandler != nil
|
||||
}
|
||||
|
||||
func (c *StreamableHTTP) listenForever(ctx context.Context) {
|
||||
c.logger.Infof("listening to server forever")
|
||||
for {
|
||||
err := c.createGETConnectionToServer(ctx)
|
||||
if errors.Is(err, ErrGetMethodNotAllowed) {
|
||||
// server does not support listening
|
||||
c.logger.Errorf("server does not support listening")
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
|
||||
}
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
|
||||
ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
|
||||
|
||||
retryInterval = 1 * time.Second // a variable is convenient for testing
|
||||
)
|
||||
|
||||
func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
|
||||
|
||||
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if we got an error response
|
||||
if resp.StatusCode == http.StatusMethodNotAllowed {
|
||||
return ErrGetMethodNotAllowed
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
// handle SSE response
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "text/event-stream" {
|
||||
return fmt.Errorf("unexpected content type: %s", contentType)
|
||||
}
|
||||
|
||||
// When ignoreResponse is true, the function will never return expect context is done.
|
||||
// NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
|
||||
// messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
|
||||
// currently, there is no convenient way to handle this response.
|
||||
// So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
|
||||
_, err = c.handleSSEResponse(ctx, resp.Body, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle SSE response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
newCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
// The original context was canceled
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
return newCtx, cancel
|
||||
}
|
||||
|
||||
106
vendor/github.com/mark3labs/mcp-go/mcp/tools.go
generated
vendored
106
vendor/github.com/mark3labs/mcp-go/mcp/tools.go
generated
vendored
@@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption {
|
||||
}
|
||||
}
|
||||
|
||||
// Items defines the schema for array items
|
||||
// Items defines the schema for array items.
|
||||
// Accepts any schema definition for maximum flexibility.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// Items(map[string]any{
|
||||
// "type": "object",
|
||||
// "properties": map[string]any{
|
||||
// "name": map[string]any{"type": "string"},
|
||||
// "age": map[string]any{"type": "number"},
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead.
|
||||
func Items(schema any) PropertyOption {
|
||||
return func(schemaMap map[string]any) {
|
||||
schemaMap["items"] = schema
|
||||
@@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption {
|
||||
schema["uniqueItems"] = unique
|
||||
}
|
||||
}
|
||||
|
||||
// WithStringItems configures an array's items to be of type string.
|
||||
//
|
||||
// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern()
|
||||
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mcp.WithArray("tags", mcp.WithStringItems())
|
||||
// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue")))
|
||||
// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50)))
|
||||
//
|
||||
// Limitations: Only supports simple string arrays. Use Items() for complex objects.
|
||||
func WithStringItems(opts ...PropertyOption) PropertyOption {
|
||||
return func(schema map[string]any) {
|
||||
itemSchema := map[string]any{
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(itemSchema)
|
||||
}
|
||||
|
||||
schema["items"] = itemSchema
|
||||
}
|
||||
}
|
||||
|
||||
// WithStringEnumItems configures an array's items to be of type string with a specified enum.
|
||||
// Example:
|
||||
//
|
||||
// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"}))
|
||||
//
|
||||
// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility.
|
||||
func WithStringEnumItems(values []string) PropertyOption {
|
||||
return func(schema map[string]any) {
|
||||
schema["items"] = map[string]any{
|
||||
"type": "string",
|
||||
"enum": values,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithNumberItems configures an array's items to be of type number.
|
||||
//
|
||||
// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf()
|
||||
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100)))
|
||||
// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0)))
|
||||
//
|
||||
// Limitations: Only supports simple number arrays. Use Items() for complex objects.
|
||||
func WithNumberItems(opts ...PropertyOption) PropertyOption {
|
||||
return func(schema map[string]any) {
|
||||
itemSchema := map[string]any{
|
||||
"type": "number",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(itemSchema)
|
||||
}
|
||||
|
||||
schema["items"] = itemSchema
|
||||
}
|
||||
}
|
||||
|
||||
// WithBooleanItems configures an array's items to be of type boolean.
|
||||
//
|
||||
// Supported options: Description(), DefaultBool()
|
||||
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// mcp.WithArray("flags", mcp.WithBooleanItems())
|
||||
// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions")))
|
||||
//
|
||||
// Limitations: Only supports simple boolean arrays. Use Items() for complex objects.
|
||||
func WithBooleanItems(opts ...PropertyOption) PropertyOption {
|
||||
return func(schema map[string]any) {
|
||||
itemSchema := map[string]any{
|
||||
"type": "boolean",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(itemSchema)
|
||||
}
|
||||
|
||||
schema["items"] = itemSchema
|
||||
}
|
||||
}
|
||||
|
||||
21
vendor/github.com/mark3labs/mcp-go/mcp/types.go
generated
vendored
21
vendor/github.com/mark3labs/mcp-go/mcp/types.go
generated
vendored
@@ -763,6 +763,11 @@ const (
|
||||
|
||||
/* Sampling */
|
||||
|
||||
const (
|
||||
// MethodSamplingCreateMessage allows servers to request LLM completions from clients
|
||||
MethodSamplingCreateMessage MCPMethod = "sampling/createMessage"
|
||||
)
|
||||
|
||||
// CreateMessageRequest is a request from the server to sample an LLM via the
|
||||
// client. The client has full discretion over which model to select. The client
|
||||
// should also inform the user before beginning sampling, to allow them to inspect
|
||||
@@ -865,6 +870,22 @@ type AudioContent struct {
|
||||
|
||||
func (AudioContent) isContent() {}
|
||||
|
||||
// ResourceLink represents a link to a resource that the client can access.
|
||||
type ResourceLink struct {
|
||||
Annotated
|
||||
Type string `json:"type"` // Must be "resource_link"
|
||||
// The URI of the resource.
|
||||
URI string `json:"uri"`
|
||||
// The name of the resource.
|
||||
Name string `json:"name"`
|
||||
// The description of the resource.
|
||||
Description string `json:"description"`
|
||||
// The MIME type of the resource.
|
||||
MIMEType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
func (ResourceLink) isContent() {}
|
||||
|
||||
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
|
||||
//
|
||||
// It is up to the client how best to render embedded resources for the
|
||||
|
||||
21
vendor/github.com/mark3labs/mcp-go/mcp/utils.go
generated
vendored
21
vendor/github.com/mark3labs/mcp-go/mcp/utils.go
generated
vendored
@@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a new ResourceLink
|
||||
func NewResourceLink(uri, name, description, mimeType string) ResourceLink {
|
||||
return ResourceLink{
|
||||
Type: "resource_link",
|
||||
URI: uri,
|
||||
Name: name,
|
||||
Description: description,
|
||||
MIMEType: mimeType,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a new EmbeddedResource
|
||||
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
|
||||
return EmbeddedResource{
|
||||
@@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) {
|
||||
}
|
||||
return NewAudioContent(data, mimeType), nil
|
||||
|
||||
case "resource_link":
|
||||
uri := ExtractString(contentMap, "uri")
|
||||
name := ExtractString(contentMap, "name")
|
||||
description := ExtractString(contentMap, "description")
|
||||
mimeType := ExtractString(contentMap, "mimeType")
|
||||
if uri == "" || name == "" {
|
||||
return nil, fmt.Errorf("resource_link uri or name is missing")
|
||||
}
|
||||
return NewResourceLink(uri, name, description, mimeType), nil
|
||||
|
||||
case "resource":
|
||||
resourceMap := ExtractMap(contentMap, "resource")
|
||||
if resourceMap == nil {
|
||||
|
||||
37
vendor/github.com/mark3labs/mcp-go/server/sampling.go
generated
vendored
Normal file
37
vendor/github.com/mark3labs/mcp-go/server/sampling.go
generated
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// EnableSampling enables sampling capabilities for the server.
|
||||
// This allows the server to send sampling requests to clients that support it.
|
||||
func (s *MCPServer) EnableSampling() {
|
||||
s.capabilitiesMu.Lock()
|
||||
defer s.capabilitiesMu.Unlock()
|
||||
}
|
||||
|
||||
// RequestSampling sends a sampling request to the client.
|
||||
// The client must have declared sampling capability during initialization.
|
||||
func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
||||
session := ClientSessionFromContext(ctx)
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("no active session")
|
||||
}
|
||||
|
||||
// Check if the session supports sampling requests
|
||||
if samplingSession, ok := session.(SessionWithSampling); ok {
|
||||
return samplingSession.RequestSampling(ctx, request)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("session does not support sampling")
|
||||
}
|
||||
|
||||
// SessionWithSampling extends ClientSession to support sampling requests.
|
||||
type SessionWithSampling interface {
|
||||
ClientSession
|
||||
RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
|
||||
}
|
||||
170
vendor/github.com/mark3labs/mcp-go/server/stdio.go
generated
vendored
170
vendor/github.com/mark3labs/mcp-go/server/stdio.go
generated
vendored
@@ -9,6 +9,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
@@ -55,6 +56,17 @@ type stdioSession struct {
|
||||
initialized atomic.Bool
|
||||
loggingLevel atomic.Value
|
||||
clientInfo atomic.Value // stores session-specific client info
|
||||
writer io.Writer // for sending requests to client
|
||||
requestID atomic.Int64 // for generating unique request IDs
|
||||
mu sync.RWMutex // protects writer
|
||||
pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
|
||||
pendingMu sync.RWMutex // protects pendingRequests
|
||||
}
|
||||
|
||||
// samplingResponse represents a response to a sampling request
|
||||
type samplingResponse struct {
|
||||
result *mcp.CreateMessageResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stdioSession) SessionID() string {
|
||||
@@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
|
||||
return level.(mcp.LoggingLevel)
|
||||
}
|
||||
|
||||
// RequestSampling sends a sampling request to the client and waits for the response.
|
||||
func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
||||
s.mu.RLock()
|
||||
writer := s.writer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if writer == nil {
|
||||
return nil, fmt.Errorf("no writer available for sending requests")
|
||||
}
|
||||
|
||||
// Generate a unique request ID
|
||||
id := s.requestID.Add(1)
|
||||
|
||||
// Create a response channel for this request
|
||||
responseChan := make(chan *samplingResponse, 1)
|
||||
s.pendingMu.Lock()
|
||||
s.pendingRequests[id] = responseChan
|
||||
s.pendingMu.Unlock()
|
||||
|
||||
// Cleanup function to remove the pending request
|
||||
cleanup := func() {
|
||||
s.pendingMu.Lock()
|
||||
delete(s.pendingRequests, id)
|
||||
s.pendingMu.Unlock()
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Create the JSON-RPC request
|
||||
jsonRPCRequest := struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params mcp.CreateMessageParams `json:"params"`
|
||||
}{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: id,
|
||||
Method: string(mcp.MethodSamplingCreateMessage),
|
||||
Params: request.CreateMessageParams,
|
||||
}
|
||||
|
||||
// Marshal and send the request
|
||||
requestBytes, err := json.Marshal(jsonRPCRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
|
||||
}
|
||||
requestBytes = append(requestBytes, '\n')
|
||||
|
||||
if _, err := writer.Write(requestBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to write sampling request: %w", err)
|
||||
}
|
||||
|
||||
// Wait for the response or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case response := <-responseChan:
|
||||
if response.err != nil {
|
||||
return nil, response.err
|
||||
}
|
||||
return response.result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetWriter sets the writer for sending requests to the client.
|
||||
func (s *stdioSession) SetWriter(writer io.Writer) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.writer = writer
|
||||
}
|
||||
|
||||
var (
|
||||
_ ClientSession = (*stdioSession)(nil)
|
||||
_ SessionWithLogging = (*stdioSession)(nil)
|
||||
_ SessionWithClientInfo = (*stdioSession)(nil)
|
||||
_ SessionWithSampling = (*stdioSession)(nil)
|
||||
)
|
||||
|
||||
var stdioSessionInstance = stdioSession{
|
||||
notifications: make(chan mcp.JSONRPCNotification, 100),
|
||||
pendingRequests: make(map[int64]chan *samplingResponse),
|
||||
}
|
||||
|
||||
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
|
||||
@@ -224,6 +308,9 @@ func (s *StdioServer) Listen(
|
||||
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
|
||||
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
|
||||
|
||||
// Set the writer for sending requests to the client
|
||||
stdioSessionInstance.SetWriter(stdout)
|
||||
|
||||
// Add in any custom context.
|
||||
if s.contextFunc != nil {
|
||||
ctx = s.contextFunc(ctx)
|
||||
@@ -256,7 +343,29 @@ func (s *StdioServer) processMessage(
|
||||
return s.writeResponse(response, writer)
|
||||
}
|
||||
|
||||
// Handle the message using the wrapped server
|
||||
// Check if this is a response to a sampling request
|
||||
if s.handleSamplingResponse(rawMessage) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
|
||||
var baseMessage struct {
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
|
||||
// Process tool calls concurrently to avoid blocking on sampling requests
|
||||
go func() {
|
||||
response := s.server.HandleMessage(ctx, rawMessage)
|
||||
if response != nil {
|
||||
if err := s.writeResponse(response, writer); err != nil {
|
||||
s.errLogger.Printf("Error writing tool response: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle other messages synchronously
|
||||
response := s.server.HandleMessage(ctx, rawMessage)
|
||||
|
||||
// Only write response if there is one (not for notifications)
|
||||
@@ -269,6 +378,65 @@ func (s *StdioServer) processMessage(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSamplingResponse checks if the message is a response to a sampling request
|
||||
// and routes it to the appropriate pending request channel.
|
||||
func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
|
||||
return stdioSessionInstance.handleSamplingResponse(rawMessage)
|
||||
}
|
||||
|
||||
// handleSamplingResponse handles incoming sampling responses for this session
|
||||
func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
|
||||
// Try to parse as a JSON-RPC response
|
||||
var response struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.Number `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(rawMessage, &response); err != nil {
|
||||
return false
|
||||
}
|
||||
// Parse the ID as int64
|
||||
idInt64, err := response.ID.Int64()
|
||||
if err != nil || (response.Result == nil && response.Error == nil) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Look for a pending request with this ID
|
||||
s.pendingMu.RLock()
|
||||
responseChan, exists := s.pendingRequests[idInt64]
|
||||
s.pendingMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
} // Parse and send the response
|
||||
samplingResp := &samplingResponse{}
|
||||
|
||||
if response.Error != nil {
|
||||
samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
|
||||
} else {
|
||||
var result mcp.CreateMessageResult
|
||||
if err := json.Unmarshal(response.Result, &result); err != nil {
|
||||
samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
|
||||
} else {
|
||||
samplingResp.result = &result
|
||||
}
|
||||
}
|
||||
|
||||
// Send the response (non-blocking)
|
||||
select {
|
||||
case responseChan <- samplingResp:
|
||||
default:
|
||||
// Channel is full or closed, ignore
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
|
||||
// Returns an error if marshaling or writing fails.
|
||||
func (s *StdioServer) writeResponse(
|
||||
|
||||
4
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
generated
vendored
4
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
generated
vendored
@@ -40,9 +40,11 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
|
||||
// to StatelessSessionIdManager.
|
||||
func WithStateLess(stateLess bool) StreamableHTTPOption {
|
||||
return func(s *StreamableHTTPServer) {
|
||||
if stateLess {
|
||||
s.sessionIdManager = &StatelessSessionIdManager{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionIdManager sets a custom session id generator for the server.
|
||||
// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
|
||||
@@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
|
||||
2
vendor/modules.txt
vendored
2
vendor/modules.txt
vendored
@@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty
|
||||
# github.com/lucasb-eyer/go-colorful v1.2.0
|
||||
## explicit; go 1.12
|
||||
github.com/lucasb-eyer/go-colorful
|
||||
# github.com/mark3labs/mcp-go v0.32.0
|
||||
# github.com/mark3labs/mcp-go v0.33.0
|
||||
## explicit; go 1.23
|
||||
github.com/mark3labs/mcp-go/client
|
||||
github.com/mark3labs/mcp-go/client/transport
|
||||
|
||||
Reference in New Issue
Block a user