mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
chore(deps): update mcp-go (#155)
* chore(deps): update mcp-go Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> * fix: vendoring Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com> --------- Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
531b3fd44f
commit
3e820ececc
2
go.mod
2
go.mod
@@ -29,7 +29,7 @@ require (
|
|||||||
github.com/fsnotify/fsnotify v1.8.0
|
github.com/fsnotify/fsnotify v1.8.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/mark3labs/mcp-go v0.32.0
|
github.com/mark3labs/mcp-go v0.33.0
|
||||||
github.com/muesli/termenv v0.16.0
|
github.com/muesli/termenv v0.16.0
|
||||||
github.com/ncruces/go-sqlite3 v0.25.0
|
github.com/ncruces/go-sqlite3 v0.25.0
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -165,8 +165,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
|||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8=
|
github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc=
|
||||||
github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
|
github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||||
|
|||||||
102
vendor/github.com/mark3labs/mcp-go/client/client.go
generated
vendored
102
vendor/github.com/mark3labs/mcp-go/client/client.go
generated
vendored
@@ -22,6 +22,7 @@ type Client struct {
|
|||||||
requestID atomic.Int64
|
requestID atomic.Int64
|
||||||
clientCapabilities mcp.ClientCapabilities
|
clientCapabilities mcp.ClientCapabilities
|
||||||
serverCapabilities mcp.ServerCapabilities
|
serverCapabilities mcp.ServerCapabilities
|
||||||
|
samplingHandler SamplingHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -33,6 +34,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSamplingHandler sets the sampling handler for the client.
|
||||||
|
// When set, the client will declare sampling capability during initialization.
|
||||||
|
func WithSamplingHandler(handler SamplingHandler) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.samplingHandler = handler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSession assumes a MCP Session has already been initialized
|
||||||
|
func WithSession() ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.initialized = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient creates a new MCP client with the given transport.
|
// NewClient creates a new MCP client with the given transport.
|
||||||
// Usage:
|
// Usage:
|
||||||
//
|
//
|
||||||
@@ -71,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
|
|||||||
handler(notification)
|
handler(notification)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Set up request handler for bidirectional communication (e.g., sampling)
|
||||||
|
if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
|
||||||
|
bidirectional.SetRequestHandler(c.handleIncomingRequest)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,6 +149,12 @@ func (c *Client) Initialize(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request mcp.InitializeRequest,
|
request mcp.InitializeRequest,
|
||||||
) (*mcp.InitializeResult, error) {
|
) (*mcp.InitializeResult, error) {
|
||||||
|
// Merge client capabilities with sampling capability if handler is configured
|
||||||
|
capabilities := request.Params.Capabilities
|
||||||
|
if c.samplingHandler != nil {
|
||||||
|
capabilities.Sampling = &struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure we send a params object with all required fields
|
// Ensure we send a params object with all required fields
|
||||||
params := struct {
|
params := struct {
|
||||||
ProtocolVersion string `json:"protocolVersion"`
|
ProtocolVersion string `json:"protocolVersion"`
|
||||||
@@ -135,7 +163,7 @@ func (c *Client) Initialize(
|
|||||||
}{
|
}{
|
||||||
ProtocolVersion: request.Params.ProtocolVersion,
|
ProtocolVersion: request.Params.ProtocolVersion,
|
||||||
ClientInfo: request.Params.ClientInfo,
|
ClientInfo: request.Params.ClientInfo,
|
||||||
Capabilities: request.Params.Capabilities, // Will be empty struct if not set
|
Capabilities: capabilities,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := c.sendRequest(ctx, "initialize", params)
|
response, err := c.sendRequest(ctx, "initialize", params)
|
||||||
@@ -398,6 +426,64 @@ func (c *Client) Complete(
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleIncomingRequest processes incoming requests from the server.
|
||||||
|
// This is the main entry point for server-to-client requests like sampling.
|
||||||
|
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
|
||||||
|
switch request.Method {
|
||||||
|
case string(mcp.MethodSamplingCreateMessage):
|
||||||
|
return c.handleSamplingRequestTransport(ctx, request)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSamplingRequestTransport handles sampling requests at the transport level.
|
||||||
|
func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
|
||||||
|
if c.samplingHandler == nil {
|
||||||
|
return nil, fmt.Errorf("no sampling handler configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the request parameters
|
||||||
|
var params mcp.CreateMessageParams
|
||||||
|
if request.Params != nil {
|
||||||
|
paramsBytes, err := json.Marshal(request.Params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal params: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(paramsBytes, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the MCP request
|
||||||
|
mcpRequest := mcp.CreateMessageRequest{
|
||||||
|
Request: mcp.Request{
|
||||||
|
Method: string(mcp.MethodSamplingCreateMessage),
|
||||||
|
},
|
||||||
|
CreateMessageParams: params,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the sampling handler
|
||||||
|
result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the result
|
||||||
|
resultBytes, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the transport response
|
||||||
|
response := &transport.JSONRPCResponse{
|
||||||
|
JSONRPC: mcp.JSONRPC_VERSION,
|
||||||
|
ID: request.ID,
|
||||||
|
Result: json.RawMessage(resultBytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
func listByPage[T any](
|
func listByPage[T any](
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *Client,
|
client *Client,
|
||||||
@@ -432,3 +518,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
|
|||||||
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
|
func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
|
||||||
return c.clientCapabilities
|
return c.clientCapabilities
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionId returns the session ID of the transport.
|
||||||
|
// If the transport does not support sessions, it returns an empty string.
|
||||||
|
func (c *Client) GetSessionId() string {
|
||||||
|
if c.transport == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.transport.GetSessionId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInitialized returns true if the client has been initialized.
|
||||||
|
func (c *Client) IsInitialized() bool {
|
||||||
|
return c.initialized
|
||||||
|
}
|
||||||
|
|||||||
7
vendor/github.com/mark3labs/mcp-go/client/http.go
generated
vendored
7
vendor/github.com/mark3labs/mcp-go/client/http.go
generated
vendored
@@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
|
return nil, fmt.Errorf("failed to create SSE transport: %w", err)
|
||||||
}
|
}
|
||||||
return NewClient(trans), nil
|
clientOptions := make([]ClientOption, 0)
|
||||||
|
sessionID := trans.GetSessionId()
|
||||||
|
if sessionID != "" {
|
||||||
|
clientOptions = append(clientOptions, WithSession())
|
||||||
|
}
|
||||||
|
return NewClient(trans, clientOptions...), nil
|
||||||
}
|
}
|
||||||
|
|||||||
20
vendor/github.com/mark3labs/mcp-go/client/sampling.go
generated
vendored
Normal file
20
vendor/github.com/mark3labs/mcp-go/client/sampling.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SamplingHandler defines the interface for handling sampling requests from servers.
|
||||||
|
// Clients can implement this interface to provide LLM sampling capabilities to servers.
|
||||||
|
type SamplingHandler interface {
|
||||||
|
// CreateMessage handles a sampling request from the server and returns the generated message.
|
||||||
|
// The implementation should:
|
||||||
|
// 1. Validate the request parameters
|
||||||
|
// 2. Optionally prompt the user for approval (human-in-the-loop)
|
||||||
|
// 3. Select an appropriate model based on preferences
|
||||||
|
// 4. Generate the response using the selected model
|
||||||
|
// 5. Return the result with model information and stop reason
|
||||||
|
CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
|
||||||
|
}
|
||||||
22
vendor/github.com/mark3labs/mcp-go/client/stdio.go
generated
vendored
22
vendor/github.com/mark3labs/mcp-go/client/stdio.go
generated
vendored
@@ -19,10 +19,26 @@ func NewStdioMCPClient(
|
|||||||
env []string,
|
env []string,
|
||||||
args ...string,
|
args ...string,
|
||||||
) (*Client, error) {
|
) (*Client, error) {
|
||||||
|
return NewStdioMCPClientWithOptions(command, env, args)
|
||||||
|
}
|
||||||
|
|
||||||
stdioTransport := transport.NewStdio(command, env, args...)
|
// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess.
|
||||||
err := stdioTransport.Start(context.Background())
|
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
|
||||||
if err != nil {
|
// Optional configuration functions can be provided to customize the transport before it starts,
|
||||||
|
// such as setting a custom command function.
|
||||||
|
//
|
||||||
|
// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport.
|
||||||
|
// Don't call the Start method manually.
|
||||||
|
// This is for backward compatibility.
|
||||||
|
func NewStdioMCPClientWithOptions(
|
||||||
|
command string,
|
||||||
|
env []string,
|
||||||
|
args []string,
|
||||||
|
opts ...transport.StdioOption,
|
||||||
|
) (*Client, error) {
|
||||||
|
stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...)
|
||||||
|
|
||||||
|
if err := stdioTransport.Start(context.Background()); err != nil {
|
||||||
return nil, fmt.Errorf("failed to start stdio transport: %w", err)
|
return nil, fmt.Errorf("failed to start stdio transport: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go
generated
vendored
4
vendor/github.com/mark3labs/mcp-go/client/transport/inprocess.go
generated
vendored
@@ -68,3 +68,7 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
|
|||||||
func (*InProcessTransport) Close() error {
|
func (*InProcessTransport) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *InProcessTransport) GetSessionId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
20
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go
generated
vendored
20
vendor/github.com/mark3labs/mcp-go/client/transport/interface.go
generated
vendored
@@ -29,6 +29,22 @@ type Interface interface {
|
|||||||
|
|
||||||
// Close the connection.
|
// Close the connection.
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
|
// GetSessionId returns the session ID of the transport.
|
||||||
|
GetSessionId() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestHandler defines a function that handles incoming requests from the server.
|
||||||
|
type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
|
||||||
|
|
||||||
|
// BidirectionalInterface extends Interface to support incoming requests from the server.
|
||||||
|
// This is used for features like sampling where the server can send requests to the client.
|
||||||
|
type BidirectionalInterface interface {
|
||||||
|
Interface
|
||||||
|
|
||||||
|
// SetRequestHandler sets the handler for incoming requests from the server.
|
||||||
|
// The handler should process the request and return a response.
|
||||||
|
SetRequestHandler(handler RequestHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
type JSONRPCRequest struct {
|
type JSONRPCRequest struct {
|
||||||
@@ -41,10 +57,10 @@ type JSONRPCRequest struct {
|
|||||||
type JSONRPCResponse struct {
|
type JSONRPCResponse struct {
|
||||||
JSONRPC string `json:"jsonrpc"`
|
JSONRPC string `json:"jsonrpc"`
|
||||||
ID mcp.RequestId `json:"id"`
|
ID mcp.RequestId `json:"id"`
|
||||||
Result json.RawMessage `json:"result"`
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
Error *struct {
|
Error *struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data json.RawMessage `json:"data"`
|
Data json.RawMessage `json:"data"`
|
||||||
} `json:"error"`
|
} `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
6
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go
generated
vendored
6
vendor/github.com/mark3labs/mcp-go/client/transport/sse.go
generated
vendored
@@ -428,6 +428,12 @@ func (c *SSE) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionId returns the session ID of the transport.
|
||||||
|
// Since SSE does not maintain a session ID, it returns an empty string.
|
||||||
|
func (c *SSE) GetSessionId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
|
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
|
||||||
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
|
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
|
||||||
if c.endpoint == nil {
|
if c.endpoint == nil {
|
||||||
|
|||||||
204
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go
generated
vendored
204
vendor/github.com/mark3labs/mcp-go/client/transport/stdio.go
generated
vendored
@@ -23,6 +23,7 @@ type Stdio struct {
|
|||||||
env []string
|
env []string
|
||||||
|
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
cmdFunc CommandFunc
|
||||||
stdin io.WriteCloser
|
stdin io.WriteCloser
|
||||||
stdout *bufio.Reader
|
stdout *bufio.Reader
|
||||||
stderr io.ReadCloser
|
stderr io.ReadCloser
|
||||||
@@ -31,6 +32,28 @@ type Stdio struct {
|
|||||||
done chan struct{}
|
done chan struct{}
|
||||||
onNotification func(mcp.JSONRPCNotification)
|
onNotification func(mcp.JSONRPCNotification)
|
||||||
notifyMu sync.RWMutex
|
notifyMu sync.RWMutex
|
||||||
|
onRequest RequestHandler
|
||||||
|
requestMu sync.RWMutex
|
||||||
|
ctx context.Context
|
||||||
|
ctxMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// StdioOption defines a function that configures a Stdio transport instance.
|
||||||
|
// Options can be used to customize the behavior of the transport before it starts,
|
||||||
|
// such as setting a custom command function.
|
||||||
|
type StdioOption func(*Stdio)
|
||||||
|
|
||||||
|
// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess.
|
||||||
|
// It can be used to apply sandboxing, custom environment control, working directories, etc.
|
||||||
|
type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error)
|
||||||
|
|
||||||
|
// WithCommandFunc sets a custom command factory function for the stdio transport.
|
||||||
|
// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess,
|
||||||
|
// allowing control over attributes like environment, working directory, and system-level sandboxing.
|
||||||
|
func WithCommandFunc(f CommandFunc) StdioOption {
|
||||||
|
return func(s *Stdio) {
|
||||||
|
s.cmdFunc = f
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIO returns a new stdio-based transport using existing input, output, and
|
// NewIO returns a new stdio-based transport using existing input, output, and
|
||||||
@@ -44,6 +67,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
|
|||||||
|
|
||||||
responses: make(map[string]chan *JSONRPCResponse),
|
responses: make(map[string]chan *JSONRPCResponse),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
|
ctx: context.Background(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,20 +79,43 @@ func NewStdio(
|
|||||||
env []string,
|
env []string,
|
||||||
args ...string,
|
args ...string,
|
||||||
) *Stdio {
|
) *Stdio {
|
||||||
|
return NewStdioWithOptions(command, env, args)
|
||||||
|
}
|
||||||
|
|
||||||
client := &Stdio{
|
// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess.
|
||||||
|
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
|
||||||
|
// Returns an error if the subprocess cannot be started or the pipes cannot be created.
|
||||||
|
// Optional configuration functions can be provided to customize the transport before it starts,
|
||||||
|
// such as setting a custom command factory.
|
||||||
|
func NewStdioWithOptions(
|
||||||
|
command string,
|
||||||
|
env []string,
|
||||||
|
args []string,
|
||||||
|
opts ...StdioOption,
|
||||||
|
) *Stdio {
|
||||||
|
s := &Stdio{
|
||||||
command: command,
|
command: command,
|
||||||
args: args,
|
args: args,
|
||||||
env: env,
|
env: env,
|
||||||
|
|
||||||
responses: make(map[string]chan *JSONRPCResponse),
|
responses: make(map[string]chan *JSONRPCResponse),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
|
ctx: context.Background(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return client
|
for _, opt := range opts {
|
||||||
|
opt(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Stdio) Start(ctx context.Context) error {
|
func (c *Stdio) Start(ctx context.Context) error {
|
||||||
|
// Store the context for use in request handling
|
||||||
|
c.ctxMu.Lock()
|
||||||
|
c.ctx = ctx
|
||||||
|
c.ctxMu.Unlock()
|
||||||
|
|
||||||
if err := c.spawnCommand(ctx); err != nil {
|
if err := c.spawnCommand(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -83,18 +130,25 @@ func (c *Stdio) Start(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// spawnCommand spawns a new process running c.command.
|
// spawnCommand spawns a new process running the configured command, args, and env.
|
||||||
|
// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
|
||||||
|
// otherwise, the default behavior uses exec.CommandContext with the merged environment.
|
||||||
|
// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
|
||||||
func (c *Stdio) spawnCommand(ctx context.Context) error {
|
func (c *Stdio) spawnCommand(ctx context.Context) error {
|
||||||
if c.command == "" {
|
if c.command == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, c.command, c.args...)
|
var cmd *exec.Cmd
|
||||||
|
var err error
|
||||||
|
|
||||||
mergedEnv := os.Environ()
|
// Standard behavior if no command func present.
|
||||||
mergedEnv = append(mergedEnv, c.env...)
|
if c.cmdFunc == nil {
|
||||||
|
cmd = exec.CommandContext(ctx, c.command, c.args...)
|
||||||
cmd.Env = mergedEnv
|
cmd.Env = append(os.Environ(), c.env...)
|
||||||
|
} else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
stdin, err := cmd.StdinPipe()
|
stdin, err := cmd.StdinPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -148,6 +202,12 @@ func (c *Stdio) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionId returns the session ID of the transport.
|
||||||
|
// Since stdio does not maintain a session ID, it returns an empty string.
|
||||||
|
func (c *Stdio) GetSessionId() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// SetNotificationHandler sets the handler function to be called when a notification is received.
|
// SetNotificationHandler sets the handler function to be called when a notification is received.
|
||||||
// Only one handler can be set at a time; setting a new one replaces the previous handler.
|
// Only one handler can be set at a time; setting a new one replaces the previous handler.
|
||||||
func (c *Stdio) SetNotificationHandler(
|
func (c *Stdio) SetNotificationHandler(
|
||||||
@@ -158,6 +218,14 @@ func (c *Stdio) SetNotificationHandler(
|
|||||||
c.onNotification = handler
|
c.onNotification = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRequestHandler sets the handler function to be called when a request is received from the server.
|
||||||
|
// This enables bidirectional communication for features like sampling.
|
||||||
|
func (c *Stdio) SetRequestHandler(handler RequestHandler) {
|
||||||
|
c.requestMu.Lock()
|
||||||
|
defer c.requestMu.Unlock()
|
||||||
|
c.onRequest = handler
|
||||||
|
}
|
||||||
|
|
||||||
// readResponses continuously reads and processes responses from the server's stdout.
|
// readResponses continuously reads and processes responses from the server's stdout.
|
||||||
// It handles both responses to requests and notifications, routing them appropriately.
|
// It handles both responses to requests and notifications, routing them appropriately.
|
||||||
// Runs until the done channel is closed or an error occurs reading from stdout.
|
// Runs until the done channel is closed or an error occurs reading from stdout.
|
||||||
@@ -175,13 +243,18 @@ func (c *Stdio) readResponses() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var baseMessage JSONRPCResponse
|
// First try to parse as a generic message to check for ID field
|
||||||
|
var baseMessage struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID *mcp.RequestId `json:"id,omitempty"`
|
||||||
|
Method string `json:"method,omitempty"`
|
||||||
|
}
|
||||||
if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
|
if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle notification
|
// If it has a method but no ID, it's a notification
|
||||||
if baseMessage.ID.IsNil() {
|
if baseMessage.Method != "" && baseMessage.ID == nil {
|
||||||
var notification mcp.JSONRPCNotification
|
var notification mcp.JSONRPCNotification
|
||||||
if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
|
if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
|
||||||
continue
|
continue
|
||||||
@@ -194,15 +267,30 @@ func (c *Stdio) readResponses() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If it has a method and an ID, it's an incoming request
|
||||||
|
if baseMessage.Method != "" && baseMessage.ID != nil {
|
||||||
|
var request JSONRPCRequest
|
||||||
|
if err := json.Unmarshal([]byte(line), &request); err == nil {
|
||||||
|
c.handleIncomingRequest(request)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, it's a response to our request
|
||||||
|
var response JSONRPCResponse
|
||||||
|
if err := json.Unmarshal([]byte(line), &response); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Create string key for map lookup
|
// Create string key for map lookup
|
||||||
idKey := baseMessage.ID.String()
|
idKey := response.ID.String()
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
ch, exists := c.responses[idKey]
|
ch, exists := c.responses[idKey]
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
ch <- &baseMessage
|
ch <- &response
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
delete(c.responses, idKey)
|
delete(c.responses, idKey)
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -281,6 +369,96 @@ func (c *Stdio) SendNotification(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleIncomingRequest processes incoming requests from the server.
|
||||||
|
// It calls the registered request handler and sends the response back to the server.
|
||||||
|
func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
|
||||||
|
c.requestMu.RLock()
|
||||||
|
handler := c.onRequest
|
||||||
|
c.requestMu.RUnlock()
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
// Send error response if no handler is configured
|
||||||
|
errorResponse := JSONRPCResponse{
|
||||||
|
JSONRPC: mcp.JSONRPC_VERSION,
|
||||||
|
ID: request.ID,
|
||||||
|
Error: &struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}{
|
||||||
|
Code: mcp.METHOD_NOT_FOUND,
|
||||||
|
Message: "No request handler configured",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.sendResponse(errorResponse)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the request in a goroutine to avoid blocking
|
||||||
|
go func() {
|
||||||
|
c.ctxMu.RLock()
|
||||||
|
ctx := c.ctx
|
||||||
|
c.ctxMu.RUnlock()
|
||||||
|
|
||||||
|
// Check if context is already cancelled before processing
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
errorResponse := JSONRPCResponse{
|
||||||
|
JSONRPC: mcp.JSONRPC_VERSION,
|
||||||
|
ID: request.ID,
|
||||||
|
Error: &struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}{
|
||||||
|
Code: mcp.INTERNAL_ERROR,
|
||||||
|
Message: ctx.Err().Error(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.sendResponse(errorResponse)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := handler(ctx, request)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
errorResponse := JSONRPCResponse{
|
||||||
|
JSONRPC: mcp.JSONRPC_VERSION,
|
||||||
|
ID: request.ID,
|
||||||
|
Error: &struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}{
|
||||||
|
Code: mcp.INTERNAL_ERROR,
|
||||||
|
Message: err.Error(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.sendResponse(errorResponse)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if response != nil {
|
||||||
|
c.sendResponse(*response)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResponse sends a response back to the server.
|
||||||
|
func (c *Stdio) sendResponse(response JSONRPCResponse) {
|
||||||
|
responseBytes, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error marshaling response: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseBytes = append(responseBytes, '\n')
|
||||||
|
|
||||||
|
if _, err := c.stdin.Write(responseBytes); err != nil {
|
||||||
|
fmt.Printf("Error writing response: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stderr returns a reader for the stderr output of the subprocess.
|
// Stderr returns a reader for the stderr output of the subprocess.
|
||||||
// This can be used to capture error messages or logs from the subprocess.
|
// This can be used to capture error messages or logs from the subprocess.
|
||||||
func (c *Stdio) Stderr() io.Reader {
|
func (c *Stdio) Stderr() io.Reader {
|
||||||
|
|||||||
371
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go
generated
vendored
371
vendor/github.com/mark3labs/mcp-go/client/transport/streamable_http.go
generated
vendored
@@ -17,10 +17,24 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mark3labs/mcp-go/mcp"
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
"github.com/mark3labs/mcp-go/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StreamableHTTPCOption func(*StreamableHTTP)
|
type StreamableHTTPCOption func(*StreamableHTTP)
|
||||||
|
|
||||||
|
// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
|
||||||
|
// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
|
||||||
|
// you should enable this option.
|
||||||
|
//
|
||||||
|
// It will establish a standalone long-live GET HTTP connection to the server.
|
||||||
|
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
|
||||||
|
// NOTICE: Even enabled, the server may not support this feature.
|
||||||
|
func WithContinuousListening() StreamableHTTPCOption {
|
||||||
|
return func(sc *StreamableHTTP) {
|
||||||
|
sc.getListeningEnabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
|
// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
|
||||||
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
|
func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
|
||||||
return func(sc *StreamableHTTP) {
|
return func(sc *StreamableHTTP) {
|
||||||
@@ -54,6 +68,19 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithLogger(logger util.Logger) StreamableHTTPCOption {
|
||||||
|
return func(sc *StreamableHTTP) {
|
||||||
|
sc.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSession creates a client with a pre-configured session
|
||||||
|
func WithSession(sessionID string) StreamableHTTPCOption {
|
||||||
|
return func(sc *StreamableHTTP) {
|
||||||
|
sc.sessionID.Store(sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// StreamableHTTP implements Streamable HTTP transport.
|
// StreamableHTTP implements Streamable HTTP transport.
|
||||||
//
|
//
|
||||||
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
|
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
|
||||||
@@ -64,19 +91,22 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
|
|||||||
//
|
//
|
||||||
// The current implementation does not support the following features:
|
// The current implementation does not support the following features:
|
||||||
// - batching
|
// - batching
|
||||||
// - continuously listening for server notifications when no request is in flight
|
|
||||||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
|
|
||||||
// - resuming stream
|
// - resuming stream
|
||||||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
|
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
|
||||||
// - server -> client request
|
// - server -> client request
|
||||||
type StreamableHTTP struct {
|
type StreamableHTTP struct {
|
||||||
serverURL *url.URL
|
serverURL *url.URL
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
headerFunc HTTPHeaderFunc
|
headerFunc HTTPHeaderFunc
|
||||||
|
logger util.Logger
|
||||||
|
getListeningEnabled bool
|
||||||
|
|
||||||
sessionID atomic.Value // string
|
sessionID atomic.Value // string
|
||||||
|
|
||||||
|
initialized chan struct{}
|
||||||
|
initializedOnce sync.Once
|
||||||
|
|
||||||
notificationHandler func(mcp.JSONRPCNotification)
|
notificationHandler func(mcp.JSONRPCNotification)
|
||||||
notifyMu sync.RWMutex
|
notifyMu sync.RWMutex
|
||||||
|
|
||||||
@@ -95,15 +125,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
|
|||||||
}
|
}
|
||||||
|
|
||||||
smc := &StreamableHTTP{
|
smc := &StreamableHTTP{
|
||||||
serverURL: parsedURL,
|
serverURL: parsedURL,
|
||||||
httpClient: &http.Client{},
|
httpClient: &http.Client{},
|
||||||
headers: make(map[string]string),
|
headers: make(map[string]string),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
|
logger: util.DefaultLogger(),
|
||||||
|
initialized: make(chan struct{}),
|
||||||
}
|
}
|
||||||
smc.sessionID.Store("") // set initial value to simplify later usage
|
smc.sessionID.Store("") // set initial value to simplify later usage
|
||||||
|
|
||||||
for _, opt := range options {
|
for _, opt := range options {
|
||||||
opt(smc)
|
if opt != nil {
|
||||||
|
opt(smc)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If OAuth is configured, set the base URL for metadata discovery
|
// If OAuth is configured, set the base URL for metadata discovery
|
||||||
@@ -118,7 +152,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
|
|||||||
|
|
||||||
// Start initiates the HTTP connection to the server.
|
// Start initiates the HTTP connection to the server.
|
||||||
func (c *StreamableHTTP) Start(ctx context.Context) error {
|
func (c *StreamableHTTP) Start(ctx context.Context) error {
|
||||||
// For Streamable HTTP, we don't need to establish a persistent connection
|
// For Streamable HTTP, we don't need to establish a persistent connection by default
|
||||||
|
if c.getListeningEnabled {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-c.initialized:
|
||||||
|
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||||
|
defer cancel()
|
||||||
|
c.listenForever(ctx)
|
||||||
|
case <-c.closed:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,13 +189,13 @@ func (c *StreamableHTTP) Close() error {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to create close request\n: %v", err)
|
c.logger.Errorf("failed to create close request: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set(headerKeySessionID, sessionId)
|
req.Header.Set(headerKeySessionID, sessionId)
|
||||||
res, err := c.httpClient.Do(req)
|
res, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to send close request\n: %v", err)
|
c.logger.Errorf("failed to send close request: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
@@ -185,37 +232,103 @@ func (c *StreamableHTTP) SendRequest(
|
|||||||
request JSONRPCRequest,
|
request JSONRPCRequest,
|
||||||
) (*JSONRPCResponse, error) {
|
) (*JSONRPCResponse, error) {
|
||||||
|
|
||||||
// Create a combined context that could be canceled when the client is closed
|
|
||||||
newCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-c.closed:
|
|
||||||
cancel()
|
|
||||||
case <-newCtx.Done():
|
|
||||||
// The original context was canceled, no need to do anything
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ctx = newCtx
|
|
||||||
|
|
||||||
// Marshal request
|
// Marshal request
|
||||||
requestBody, err := json.Marshal(request)
|
requestBody, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
|
||||||
|
// If the request is initialize, should not return a SessionTerminated error
|
||||||
|
// It should be a genuine endpoint-routing issue.
|
||||||
|
// ( Fall through to return StatusCode checking. )
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check if we got an error response
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||||
|
|
||||||
|
// Handle OAuth unauthorized error
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
|
||||||
|
return nil, &OAuthAuthorizationRequiredError{
|
||||||
|
Handler: c.oauthHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle error response
|
||||||
|
var errResponse JSONRPCResponse
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if err := json.Unmarshal(body, &errResponse); err == nil {
|
||||||
|
return &errResponse, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Method == string(mcp.MethodInitialize) {
|
||||||
|
// saved the received session ID in the response
|
||||||
|
// empty session ID is allowed
|
||||||
|
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
|
||||||
|
c.sessionID.Store(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.initializedOnce.Do(func() {
|
||||||
|
close(c.initialized)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different response types
|
||||||
|
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||||
|
switch mediaType {
|
||||||
|
case "application/json":
|
||||||
|
// Single response
|
||||||
|
var response JSONRPCResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// should not be a notification
|
||||||
|
if response.ID.IsNil() {
|
||||||
|
return nil, fmt.Errorf("response should contain RPC id: %v", response)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, nil
|
||||||
|
|
||||||
|
case "text/event-stream":
|
||||||
|
// Server is using SSE for streaming responses
|
||||||
|
return c.handleSSEResponse(ctx, resp.Body, false)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StreamableHTTP) sendHTTP(
|
||||||
|
ctx context.Context,
|
||||||
|
method string,
|
||||||
|
body io.Reader,
|
||||||
|
acceptType string,
|
||||||
|
) (resp *http.Response, err error) {
|
||||||
|
|
||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
|
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Accept", "application/json, text/event-stream")
|
req.Header.Set("Accept", acceptType)
|
||||||
sessionID := c.sessionID.Load()
|
sessionID := c.sessionID.Load().(string)
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
req.Header.Set(headerKeySessionID, sessionID.(string))
|
req.Header.Set(headerKeySessionID, sessionID)
|
||||||
}
|
}
|
||||||
for k, v := range c.headers {
|
for k, v := range c.headers {
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
@@ -243,73 +356,24 @@ func (c *StreamableHTTP) SendRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err = c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Check if we got an error response
|
// universal handling for session terminated
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
// handle session closed
|
c.sessionID.CompareAndSwap(sessionID, "")
|
||||||
if resp.StatusCode == http.StatusNotFound {
|
return nil, ErrSessionTerminated
|
||||||
c.sessionID.CompareAndSwap(sessionID, "")
|
|
||||||
return nil, fmt.Errorf("session terminated (404). need to re-initialize")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle OAuth unauthorized error
|
|
||||||
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
|
|
||||||
return nil, &OAuthAuthorizationRequiredError{
|
|
||||||
Handler: c.oauthHandler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle error response
|
|
||||||
var errResponse JSONRPCResponse
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
if err := json.Unmarshal(body, &errResponse); err == nil {
|
|
||||||
return &errResponse, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Method == string(mcp.MethodInitialize) {
|
return resp, nil
|
||||||
// saved the received session ID in the response
|
|
||||||
// empty session ID is allowed
|
|
||||||
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
|
|
||||||
c.sessionID.Store(sessionID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle different response types
|
|
||||||
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
|
||||||
switch mediaType {
|
|
||||||
case "application/json":
|
|
||||||
// Single response
|
|
||||||
var response JSONRPCResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// should not be a notification
|
|
||||||
if response.ID.IsNil() {
|
|
||||||
return nil, fmt.Errorf("response should contain RPC id: %v", response)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &response, nil
|
|
||||||
|
|
||||||
case "text/event-stream":
|
|
||||||
// Server is using SSE for streaming responses
|
|
||||||
return c.handleSSEResponse(ctx, resp.Body)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleSSEResponse processes an SSE stream for a specific request.
|
// handleSSEResponse processes an SSE stream for a specific request.
|
||||||
// It returns the final result for the request once received, or an error.
|
// It returns the final result for the request once received, or an error.
|
||||||
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
|
// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
|
||||||
|
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {
|
||||||
|
|
||||||
// Create a channel for this specific request
|
// Create a channel for this specific request
|
||||||
responseChan := make(chan *JSONRPCResponse, 1)
|
responseChan := make(chan *JSONRPCResponse, 1)
|
||||||
@@ -328,7 +392,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
|||||||
|
|
||||||
var message JSONRPCResponse
|
var message JSONRPCResponse
|
||||||
if err := json.Unmarshal([]byte(data), &message); err != nil {
|
if err := json.Unmarshal([]byte(data), &message); err != nil {
|
||||||
fmt.Printf("failed to unmarshal message: %v\n", err)
|
c.logger.Errorf("failed to unmarshal message: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,7 +400,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
|||||||
if message.ID.IsNil() {
|
if message.ID.IsNil() {
|
||||||
var notification mcp.JSONRPCNotification
|
var notification mcp.JSONRPCNotification
|
||||||
if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
|
if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
|
||||||
fmt.Printf("failed to unmarshal notification: %v\n", err)
|
c.logger.Errorf("failed to unmarshal notification: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.notifyMu.RLock()
|
c.notifyMu.RLock()
|
||||||
@@ -347,7 +411,9 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
responseChan <- &message
|
if !ignoreResponse {
|
||||||
|
responseChan <- &message
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -393,7 +459,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
fmt.Printf("SSE stream error: %v\n", err)
|
c.logger.Errorf("SSE stream error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -432,44 +498,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
|
ctx, cancel := c.contextAwareOfClientClose(ctx)
|
||||||
if err != nil {
|
defer cancel()
|
||||||
return fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set headers
|
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json, text/event-stream")
|
|
||||||
if sessionID := c.sessionID.Load(); sessionID != "" {
|
|
||||||
req.Header.Set(headerKeySessionID, sessionID.(string))
|
|
||||||
}
|
|
||||||
for k, v := range c.headers {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add OAuth authorization if configured
|
|
||||||
if c.oauthHandler != nil {
|
|
||||||
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// If we get an authorization error, return a specific error that can be handled by the client
|
|
||||||
if errors.Is(err, ErrOAuthAuthorizationRequired) {
|
|
||||||
return &OAuthAuthorizationRequiredError{
|
|
||||||
Handler: c.oauthHandler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to get authorization header: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", authHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.headerFunc != nil {
|
|
||||||
for k, v := range c.headerFunc(ctx) {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send request
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send request: %w", err)
|
return fmt.Errorf("failed to send request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -513,3 +545,84 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
|
|||||||
func (c *StreamableHTTP) IsOAuthEnabled() bool {
|
func (c *StreamableHTTP) IsOAuthEnabled() bool {
|
||||||
return c.oauthHandler != nil
|
return c.oauthHandler != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *StreamableHTTP) listenForever(ctx context.Context) {
|
||||||
|
c.logger.Infof("listening to server forever")
|
||||||
|
for {
|
||||||
|
err := c.createGETConnectionToServer(ctx)
|
||||||
|
if errors.Is(err, ErrGetMethodNotAllowed) {
|
||||||
|
// server does not support listening
|
||||||
|
c.logger.Errorf("server does not support listening")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(retryInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
|
||||||
|
ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
|
||||||
|
|
||||||
|
retryInterval = 1 * time.Second // a variable is convenient for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
|
||||||
|
|
||||||
|
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check if we got an error response
|
||||||
|
if resp.StatusCode == http.StatusMethodNotAllowed {
|
||||||
|
return ErrGetMethodNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle SSE response
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != "text/event-stream" {
|
||||||
|
return fmt.Errorf("unexpected content type: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When ignoreResponse is true, the function will never return expect context is done.
|
||||||
|
// NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response
|
||||||
|
// messages. To be more compatible, we should handle this response, however, as the transport layer is message-based,
|
||||||
|
// currently, there is no convenient way to handle this response.
|
||||||
|
// So we ignore the response here. It's not a bug, but may be not compatible with other SDKs.
|
||||||
|
_, err = c.handleSSEResponse(ctx, resp.Body, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to handle SSE response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
newCtx, cancel := context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
cancel()
|
||||||
|
case <-newCtx.Done():
|
||||||
|
// The original context was canceled
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return newCtx, cancel
|
||||||
|
}
|
||||||
|
|||||||
106
vendor/github.com/mark3labs/mcp-go/mcp/tools.go
generated
vendored
106
vendor/github.com/mark3labs/mcp-go/mcp/tools.go
generated
vendored
@@ -945,7 +945,20 @@ func PropertyNames(schema map[string]any) PropertyOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Items defines the schema for array items
|
// Items defines the schema for array items.
|
||||||
|
// Accepts any schema definition for maximum flexibility.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// Items(map[string]any{
|
||||||
|
// "type": "object",
|
||||||
|
// "properties": map[string]any{
|
||||||
|
// "name": map[string]any{"type": "string"},
|
||||||
|
// "age": map[string]any{"type": "number"},
|
||||||
|
// },
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead.
|
||||||
func Items(schema any) PropertyOption {
|
func Items(schema any) PropertyOption {
|
||||||
return func(schemaMap map[string]any) {
|
return func(schemaMap map[string]any) {
|
||||||
schemaMap["items"] = schema
|
schemaMap["items"] = schema
|
||||||
@@ -972,3 +985,94 @@ func UniqueItems(unique bool) PropertyOption {
|
|||||||
schema["uniqueItems"] = unique
|
schema["uniqueItems"] = unique
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStringItems configures an array's items to be of type string.
|
||||||
|
//
|
||||||
|
// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern()
|
||||||
|
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
//
|
||||||
|
// mcp.WithArray("tags", mcp.WithStringItems())
|
||||||
|
// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue")))
|
||||||
|
// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50)))
|
||||||
|
//
|
||||||
|
// Limitations: Only supports simple string arrays. Use Items() for complex objects.
|
||||||
|
func WithStringItems(opts ...PropertyOption) PropertyOption {
|
||||||
|
return func(schema map[string]any) {
|
||||||
|
itemSchema := map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(itemSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema["items"] = itemSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStringEnumItems configures an array's items to be of type string with a specified enum.
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"}))
|
||||||
|
//
|
||||||
|
// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility.
|
||||||
|
func WithStringEnumItems(values []string) PropertyOption {
|
||||||
|
return func(schema map[string]any) {
|
||||||
|
schema["items"] = map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithNumberItems configures an array's items to be of type number.
|
||||||
|
//
|
||||||
|
// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf()
|
||||||
|
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
//
|
||||||
|
// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100)))
|
||||||
|
// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0)))
|
||||||
|
//
|
||||||
|
// Limitations: Only supports simple number arrays. Use Items() for complex objects.
|
||||||
|
func WithNumberItems(opts ...PropertyOption) PropertyOption {
|
||||||
|
return func(schema map[string]any) {
|
||||||
|
itemSchema := map[string]any{
|
||||||
|
"type": "number",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(itemSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema["items"] = itemSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBooleanItems configures an array's items to be of type boolean.
|
||||||
|
//
|
||||||
|
// Supported options: Description(), DefaultBool()
|
||||||
|
// Note: Options like Required() are not valid for item schemas and will be ignored.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
//
|
||||||
|
// mcp.WithArray("flags", mcp.WithBooleanItems())
|
||||||
|
// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions")))
|
||||||
|
//
|
||||||
|
// Limitations: Only supports simple boolean arrays. Use Items() for complex objects.
|
||||||
|
func WithBooleanItems(opts ...PropertyOption) PropertyOption {
|
||||||
|
return func(schema map[string]any) {
|
||||||
|
itemSchema := map[string]any{
|
||||||
|
"type": "boolean",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(itemSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema["items"] = itemSchema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
21
vendor/github.com/mark3labs/mcp-go/mcp/types.go
generated
vendored
21
vendor/github.com/mark3labs/mcp-go/mcp/types.go
generated
vendored
@@ -763,6 +763,11 @@ const (
|
|||||||
|
|
||||||
/* Sampling */
|
/* Sampling */
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MethodSamplingCreateMessage allows servers to request LLM completions from clients
|
||||||
|
MethodSamplingCreateMessage MCPMethod = "sampling/createMessage"
|
||||||
|
)
|
||||||
|
|
||||||
// CreateMessageRequest is a request from the server to sample an LLM via the
|
// CreateMessageRequest is a request from the server to sample an LLM via the
|
||||||
// client. The client has full discretion over which model to select. The client
|
// client. The client has full discretion over which model to select. The client
|
||||||
// should also inform the user before beginning sampling, to allow them to inspect
|
// should also inform the user before beginning sampling, to allow them to inspect
|
||||||
@@ -865,6 +870,22 @@ type AudioContent struct {
|
|||||||
|
|
||||||
func (AudioContent) isContent() {}
|
func (AudioContent) isContent() {}
|
||||||
|
|
||||||
|
// ResourceLink represents a link to a resource that the client can access.
|
||||||
|
type ResourceLink struct {
|
||||||
|
Annotated
|
||||||
|
Type string `json:"type"` // Must be "resource_link"
|
||||||
|
// The URI of the resource.
|
||||||
|
URI string `json:"uri"`
|
||||||
|
// The name of the resource.
|
||||||
|
Name string `json:"name"`
|
||||||
|
// The description of the resource.
|
||||||
|
Description string `json:"description"`
|
||||||
|
// The MIME type of the resource.
|
||||||
|
MIMEType string `json:"mimeType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ResourceLink) isContent() {}
|
||||||
|
|
||||||
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
|
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
|
||||||
//
|
//
|
||||||
// It is up to the client how best to render embedded resources for the
|
// It is up to the client how best to render embedded resources for the
|
||||||
|
|||||||
21
vendor/github.com/mark3labs/mcp-go/mcp/utils.go
generated
vendored
21
vendor/github.com/mark3labs/mcp-go/mcp/utils.go
generated
vendored
@@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to create a new ResourceLink
|
||||||
|
func NewResourceLink(uri, name, description, mimeType string) ResourceLink {
|
||||||
|
return ResourceLink{
|
||||||
|
Type: "resource_link",
|
||||||
|
URI: uri,
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
MIMEType: mimeType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create a new EmbeddedResource
|
// Helper function to create a new EmbeddedResource
|
||||||
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
|
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
|
||||||
return EmbeddedResource{
|
return EmbeddedResource{
|
||||||
@@ -476,6 +487,16 @@ func ParseContent(contentMap map[string]any) (Content, error) {
|
|||||||
}
|
}
|
||||||
return NewAudioContent(data, mimeType), nil
|
return NewAudioContent(data, mimeType), nil
|
||||||
|
|
||||||
|
case "resource_link":
|
||||||
|
uri := ExtractString(contentMap, "uri")
|
||||||
|
name := ExtractString(contentMap, "name")
|
||||||
|
description := ExtractString(contentMap, "description")
|
||||||
|
mimeType := ExtractString(contentMap, "mimeType")
|
||||||
|
if uri == "" || name == "" {
|
||||||
|
return nil, fmt.Errorf("resource_link uri or name is missing")
|
||||||
|
}
|
||||||
|
return NewResourceLink(uri, name, description, mimeType), nil
|
||||||
|
|
||||||
case "resource":
|
case "resource":
|
||||||
resourceMap := ExtractMap(contentMap, "resource")
|
resourceMap := ExtractMap(contentMap, "resource")
|
||||||
if resourceMap == nil {
|
if resourceMap == nil {
|
||||||
|
|||||||
37
vendor/github.com/mark3labs/mcp-go/server/sampling.go
generated
vendored
Normal file
37
vendor/github.com/mark3labs/mcp-go/server/sampling.go
generated
vendored
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableSampling enables sampling capabilities for the server.
|
||||||
|
// This allows the server to send sampling requests to clients that support it.
|
||||||
|
func (s *MCPServer) EnableSampling() {
|
||||||
|
s.capabilitiesMu.Lock()
|
||||||
|
defer s.capabilitiesMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestSampling sends a sampling request to the client.
|
||||||
|
// The client must have declared sampling capability during initialization.
|
||||||
|
func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
||||||
|
session := ClientSessionFromContext(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return nil, fmt.Errorf("no active session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the session supports sampling requests
|
||||||
|
if samplingSession, ok := session.(SessionWithSampling); ok {
|
||||||
|
return samplingSession.RequestSampling(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("session does not support sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionWithSampling extends ClientSession to support sampling requests.
|
||||||
|
type SessionWithSampling interface {
|
||||||
|
ClientSession
|
||||||
|
RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
|
||||||
|
}
|
||||||
180
vendor/github.com/mark3labs/mcp-go/server/stdio.go
generated
vendored
180
vendor/github.com/mark3labs/mcp-go/server/stdio.go
generated
vendored
@@ -9,6 +9,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
@@ -51,10 +52,21 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
|
|||||||
|
|
||||||
// stdioSession is a static client session, since stdio has only one client.
|
// stdioSession is a static client session, since stdio has only one client.
|
||||||
type stdioSession struct {
|
type stdioSession struct {
|
||||||
notifications chan mcp.JSONRPCNotification
|
notifications chan mcp.JSONRPCNotification
|
||||||
initialized atomic.Bool
|
initialized atomic.Bool
|
||||||
loggingLevel atomic.Value
|
loggingLevel atomic.Value
|
||||||
clientInfo atomic.Value // stores session-specific client info
|
clientInfo atomic.Value // stores session-specific client info
|
||||||
|
writer io.Writer // for sending requests to client
|
||||||
|
requestID atomic.Int64 // for generating unique request IDs
|
||||||
|
mu sync.RWMutex // protects writer
|
||||||
|
pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
|
||||||
|
pendingMu sync.RWMutex // protects pendingRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
// samplingResponse represents a response to a sampling request
|
||||||
|
type samplingResponse struct {
|
||||||
|
result *mcp.CreateMessageResult
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stdioSession) SessionID() string {
|
func (s *stdioSession) SessionID() string {
|
||||||
@@ -100,14 +112,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
|
|||||||
return level.(mcp.LoggingLevel)
|
return level.(mcp.LoggingLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestSampling sends a sampling request to the client and waits for the response.
|
||||||
|
func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
writer := s.writer
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if writer == nil {
|
||||||
|
return nil, fmt.Errorf("no writer available for sending requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique request ID
|
||||||
|
id := s.requestID.Add(1)
|
||||||
|
|
||||||
|
// Create a response channel for this request
|
||||||
|
responseChan := make(chan *samplingResponse, 1)
|
||||||
|
s.pendingMu.Lock()
|
||||||
|
s.pendingRequests[id] = responseChan
|
||||||
|
s.pendingMu.Unlock()
|
||||||
|
|
||||||
|
// Cleanup function to remove the pending request
|
||||||
|
cleanup := func() {
|
||||||
|
s.pendingMu.Lock()
|
||||||
|
delete(s.pendingRequests, id)
|
||||||
|
s.pendingMu.Unlock()
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create the JSON-RPC request
|
||||||
|
jsonRPCRequest := struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params mcp.CreateMessageParams `json:"params"`
|
||||||
|
}{
|
||||||
|
JSONRPC: mcp.JSONRPC_VERSION,
|
||||||
|
ID: id,
|
||||||
|
Method: string(mcp.MethodSamplingCreateMessage),
|
||||||
|
Params: request.CreateMessageParams,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal and send the request
|
||||||
|
requestBytes, err := json.Marshal(jsonRPCRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
|
||||||
|
}
|
||||||
|
requestBytes = append(requestBytes, '\n')
|
||||||
|
|
||||||
|
if _, err := writer.Write(requestBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write sampling request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the response or context cancellation
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case response := <-responseChan:
|
||||||
|
if response.err != nil {
|
||||||
|
return nil, response.err
|
||||||
|
}
|
||||||
|
return response.result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriter sets the writer for sending requests to the client.
|
||||||
|
func (s *stdioSession) SetWriter(writer io.Writer) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.writer = writer
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ ClientSession = (*stdioSession)(nil)
|
_ ClientSession = (*stdioSession)(nil)
|
||||||
_ SessionWithLogging = (*stdioSession)(nil)
|
_ SessionWithLogging = (*stdioSession)(nil)
|
||||||
_ SessionWithClientInfo = (*stdioSession)(nil)
|
_ SessionWithClientInfo = (*stdioSession)(nil)
|
||||||
|
_ SessionWithSampling = (*stdioSession)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
var stdioSessionInstance = stdioSession{
|
var stdioSessionInstance = stdioSession{
|
||||||
notifications: make(chan mcp.JSONRPCNotification, 100),
|
notifications: make(chan mcp.JSONRPCNotification, 100),
|
||||||
|
pendingRequests: make(map[int64]chan *samplingResponse),
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
|
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
|
||||||
@@ -224,6 +308,9 @@ func (s *StdioServer) Listen(
|
|||||||
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
|
defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
|
||||||
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
|
ctx = s.server.WithContext(ctx, &stdioSessionInstance)
|
||||||
|
|
||||||
|
// Set the writer for sending requests to the client
|
||||||
|
stdioSessionInstance.SetWriter(stdout)
|
||||||
|
|
||||||
// Add in any custom context.
|
// Add in any custom context.
|
||||||
if s.contextFunc != nil {
|
if s.contextFunc != nil {
|
||||||
ctx = s.contextFunc(ctx)
|
ctx = s.contextFunc(ctx)
|
||||||
@@ -256,7 +343,29 @@ func (s *StdioServer) processMessage(
|
|||||||
return s.writeResponse(response, writer)
|
return s.writeResponse(response, writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle the message using the wrapped server
|
// Check if this is a response to a sampling request
|
||||||
|
if s.handleSamplingResponse(rawMessage) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
|
||||||
|
var baseMessage struct {
|
||||||
|
Method string `json:"method"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
|
||||||
|
// Process tool calls concurrently to avoid blocking on sampling requests
|
||||||
|
go func() {
|
||||||
|
response := s.server.HandleMessage(ctx, rawMessage)
|
||||||
|
if response != nil {
|
||||||
|
if err := s.writeResponse(response, writer); err != nil {
|
||||||
|
s.errLogger.Printf("Error writing tool response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle other messages synchronously
|
||||||
response := s.server.HandleMessage(ctx, rawMessage)
|
response := s.server.HandleMessage(ctx, rawMessage)
|
||||||
|
|
||||||
// Only write response if there is one (not for notifications)
|
// Only write response if there is one (not for notifications)
|
||||||
@@ -269,6 +378,65 @@ func (s *StdioServer) processMessage(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleSamplingResponse checks if the message is a response to a sampling request
|
||||||
|
// and routes it to the appropriate pending request channel.
|
||||||
|
func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
|
||||||
|
return stdioSessionInstance.handleSamplingResponse(rawMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSamplingResponse handles incoming sampling responses for this session
|
||||||
|
func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
|
||||||
|
// Try to parse as a JSON-RPC response
|
||||||
|
var response struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
ID json.Number `json:"id"`
|
||||||
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
|
Error *struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(rawMessage, &response); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Parse the ID as int64
|
||||||
|
idInt64, err := response.ID.Int64()
|
||||||
|
if err != nil || (response.Result == nil && response.Error == nil) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for a pending request with this ID
|
||||||
|
s.pendingMu.RLock()
|
||||||
|
responseChan, exists := s.pendingRequests[idInt64]
|
||||||
|
s.pendingMu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
} // Parse and send the response
|
||||||
|
samplingResp := &samplingResponse{}
|
||||||
|
|
||||||
|
if response.Error != nil {
|
||||||
|
samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
|
||||||
|
} else {
|
||||||
|
var result mcp.CreateMessageResult
|
||||||
|
if err := json.Unmarshal(response.Result, &result); err != nil {
|
||||||
|
samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
|
||||||
|
} else {
|
||||||
|
samplingResp.result = &result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the response (non-blocking)
|
||||||
|
select {
|
||||||
|
case responseChan <- samplingResp:
|
||||||
|
default:
|
||||||
|
// Channel is full or closed, ignore
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
|
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
|
||||||
// Returns an error if marshaling or writing fails.
|
// Returns an error if marshaling or writing fails.
|
||||||
func (s *StdioServer) writeResponse(
|
func (s *StdioServer) writeResponse(
|
||||||
|
|||||||
6
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
generated
vendored
6
vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
generated
vendored
@@ -40,7 +40,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption {
|
|||||||
// to StatelessSessionIdManager.
|
// to StatelessSessionIdManager.
|
||||||
func WithStateLess(stateLess bool) StreamableHTTPOption {
|
func WithStateLess(stateLess bool) StreamableHTTPOption {
|
||||||
return func(s *StreamableHTTPServer) {
|
return func(s *StreamableHTTPServer) {
|
||||||
s.sessionIdManager = &StatelessSessionIdManager{}
|
if stateLess {
|
||||||
|
s.sessionIdManager = &StatelessSessionIdManager{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +376,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
|
|||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
2
vendor/modules.txt
vendored
2
vendor/modules.txt
vendored
@@ -403,7 +403,7 @@ github.com/kylelemons/godebug/pretty
|
|||||||
# github.com/lucasb-eyer/go-colorful v1.2.0
|
# github.com/lucasb-eyer/go-colorful v1.2.0
|
||||||
## explicit; go 1.12
|
## explicit; go 1.12
|
||||||
github.com/lucasb-eyer/go-colorful
|
github.com/lucasb-eyer/go-colorful
|
||||||
# github.com/mark3labs/mcp-go v0.32.0
|
# github.com/mark3labs/mcp-go v0.33.0
|
||||||
## explicit; go 1.23
|
## explicit; go 1.23
|
||||||
github.com/mark3labs/mcp-go/client
|
github.com/mark3labs/mcp-go/client
|
||||||
github.com/mark3labs/mcp-go/client/transport
|
github.com/mark3labs/mcp-go/client/transport
|
||||||
|
|||||||
Reference in New Issue
Block a user