mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
test(mcp): Add integration tests for MCP session correlation and header transmission
- Test Claude's transmission of custom headers including X-Session-ID - Verify approval creation with tool_use_id in Phase 4 implementation - Test auto-approval behavior with dangerously_skip_permissions flag - Confirm session correlation works correctly via HTTP headers
This commit is contained in:
443
hld/daemon/mcp_claudecode_integration_test.go
Normal file
443
hld/daemon/mcp_claudecode_integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MCPRequestLog captures all details of an MCP request
|
||||
type MCPRequestLog struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string][]string
|
||||
Body json.RawMessage
|
||||
Timestamp time.Time
|
||||
RequestID int
|
||||
}
|
||||
|
||||
// MCPTestServer wraps the real MCP server and logs all requests
|
||||
type MCPTestServer struct {
|
||||
realHandler http.Handler
|
||||
requests []MCPRequestLog
|
||||
requestsMutex sync.Mutex
|
||||
requestCount int
|
||||
}
|
||||
|
||||
func (s *MCPTestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture request details
|
||||
s.requestsMutex.Lock()
|
||||
s.requestCount++
|
||||
requestID := s.requestCount
|
||||
|
||||
// Read body for logging
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for real handler
|
||||
|
||||
// Log all headers
|
||||
headersCopy := make(map[string][]string)
|
||||
for k, v := range r.Header {
|
||||
headersCopy[k] = v
|
||||
}
|
||||
|
||||
log := MCPRequestLog{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Headers: headersCopy,
|
||||
Body: bodyBytes,
|
||||
Timestamp: time.Now(),
|
||||
RequestID: requestID,
|
||||
}
|
||||
s.requests = append(s.requests, log)
|
||||
s.requestsMutex.Unlock()
|
||||
|
||||
// Log request details to test output
|
||||
fmt.Printf("\n[MCP Request #%d] %s %s\n", requestID, r.Method, r.URL.Path)
|
||||
fmt.Printf("Headers:\n")
|
||||
for k, v := range r.Header {
|
||||
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
|
||||
}
|
||||
if len(bodyBytes) > 0 {
|
||||
fmt.Printf("Body: %s\n", string(bodyBytes))
|
||||
}
|
||||
fmt.Printf("---\n")
|
||||
|
||||
// Forward to real handler
|
||||
s.realHandler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *MCPTestServer) GetRequests() []MCPRequestLog {
|
||||
s.requestsMutex.Lock()
|
||||
defer s.requestsMutex.Unlock()
|
||||
return append([]MCPRequestLog{}, s.requests...)
|
||||
}
|
||||
|
||||
func TestMCPClaudeCodeSessionIDCorrelation(t *testing.T) {
|
||||
// Skip if Claude is not available
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("Claude CLI not available, skipping integration test")
|
||||
}
|
||||
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-claudecode")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-claudecode")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true") // Auto-deny for predictable responses
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create MCP test server wrapper
|
||||
mcpTestServer := &MCPTestServer{
|
||||
requests: []MCPRequestLog{},
|
||||
}
|
||||
|
||||
// Custom HTTP server setup to wrap MCP handler
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.New()
|
||||
|
||||
// Add health endpoint
|
||||
router.GET("/api/v1/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Wrap MCP endpoint with test server
|
||||
router.Any("/api/v1/mcp", func(c *gin.Context) {
|
||||
// First time setup - get real handler from daemon
|
||||
if mcpTestServer.realHandler == nil {
|
||||
// Get the real MCP handler from daemon
|
||||
// We'll create a simple MCP handler that auto-denies
|
||||
mcpTestServer.realHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple MCP response for testing
|
||||
var req map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
|
||||
response := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "initialize":
|
||||
response["result"] = map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": map[string]interface{}{
|
||||
"name": "test-mcp-server",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"capabilities": map[string]interface{}{
|
||||
"tools": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
case "tools/list":
|
||||
response["result"] = map[string]interface{}{
|
||||
"tools": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"description": "Request permission to execute a tool",
|
||||
"inputSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"tool_name": map[string]string{"type": "string"},
|
||||
"input": map[string]string{"type": "object"},
|
||||
"tool_use_id": map[string]string{"type": "string"},
|
||||
},
|
||||
"required": []string{"tool_name", "input", "tool_use_id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case "tools/call":
|
||||
// Auto-deny
|
||||
response["result"] = map[string]interface{}{
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": `{"behavior": "deny", "message": "Auto-denied for testing"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
response["error"] = map[string]interface{}{
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
})
|
||||
}
|
||||
|
||||
mcpTestServer.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Start HTTP server
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", httpPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &http.Server{
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
go func() {
|
||||
server.Serve(listener)
|
||||
}()
|
||||
defer server.Shutdown(context.Background())
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session in the database
|
||||
testSessionID := "test-claudecode-session"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-claudecode', 'claude-test', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 1, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, testSessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create claudecode client
|
||||
client, err := claudecode.NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare MCP configuration
|
||||
// The claudecode client will write this to a temp file and pass it to claude
|
||||
// We need to match the format claude expects for HTTP MCP servers
|
||||
mcpConfig := &claudecode.MCPConfig{
|
||||
MCPServers: map[string]claudecode.MCPServer{
|
||||
"humanlayer": {
|
||||
Command: "http", // This is just a placeholder
|
||||
Env: map[string]string{
|
||||
// The actual config will be written as JSON with type/url/headers
|
||||
"_config": fmt.Sprintf(`{"type":"http","url":"%s/api/v1/mcp","headers":{"X-Session-ID":"%s"}}`, baseURL, testSessionID),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create session config
|
||||
sessionConfig := claudecode.SessionConfig{
|
||||
Query: "Say 'test complete' and exit",
|
||||
Model: claudecode.ModelSonnet,
|
||||
OutputFormat: claudecode.OutputStreamJSON,
|
||||
MCPConfig: mcpConfig,
|
||||
PermissionPromptTool: "mcp__humanlayer__request_approval",
|
||||
MaxTurns: 1,
|
||||
WorkingDir: tempDir,
|
||||
Verbose: true,
|
||||
}
|
||||
|
||||
// Capture events from Claude
|
||||
var allEvents []claudecode.StreamEvent
|
||||
var eventsMutex sync.Mutex
|
||||
|
||||
// Launch Claude session
|
||||
t.Log("Launching Claude session with MCP config...")
|
||||
session, err := client.Launch(sessionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Capture events in background
|
||||
eventsDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(eventsDone)
|
||||
for event := range session.Events {
|
||||
eventsMutex.Lock()
|
||||
allEvents = append(allEvents, event)
|
||||
eventsMutex.Unlock()
|
||||
|
||||
// Log significant events
|
||||
switch event.Type {
|
||||
case "system":
|
||||
if event.Subtype == "init" {
|
||||
t.Logf("Claude session initialized: ID=%s, Model=%s", event.SessionID, event.Model)
|
||||
}
|
||||
case "mcp_servers":
|
||||
for _, server := range event.MCPServers {
|
||||
t.Logf("MCP Server %s: %s", server.Name, server.Status)
|
||||
}
|
||||
case "result":
|
||||
t.Logf("Session completed: ID=%s, Error=%v", event.SessionID, event.IsError)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for session to complete (with timeout)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
result, err := session.Wait()
|
||||
if err != nil {
|
||||
t.Logf("Session error: %v", err)
|
||||
} else if result != nil {
|
||||
t.Logf("Session result: %s", result.Result)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Session completed
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Log("Session timeout, interrupting...")
|
||||
session.Interrupt()
|
||||
<-done
|
||||
}
|
||||
|
||||
// Wait for events to be processed
|
||||
<-eventsDone
|
||||
|
||||
// Analyze captured MCP requests
|
||||
requests := mcpTestServer.GetRequests()
|
||||
t.Logf("\n=== MCP Request Analysis ===")
|
||||
t.Logf("Total MCP requests: %d", len(requests))
|
||||
|
||||
// Check for session ID in headers
|
||||
sessionIDFound := false
|
||||
var sessionIDHeaders []string
|
||||
|
||||
for i, req := range requests {
|
||||
t.Logf("\nRequest #%d: %s", i+1, req.Method)
|
||||
|
||||
// Check various possible session ID headers
|
||||
possibleHeaders := []string{
|
||||
"X-Session-ID",
|
||||
"X-Session-Id",
|
||||
"Session-ID",
|
||||
"Session-Id",
|
||||
"Mcp-Session-Id",
|
||||
"MCP-Session-ID",
|
||||
}
|
||||
|
||||
for _, header := range possibleHeaders {
|
||||
if values, ok := req.Headers[header]; ok && len(values) > 0 {
|
||||
sessionIDFound = true
|
||||
sessionIDHeaders = append(sessionIDHeaders, fmt.Sprintf("%s: %s", header, values[0]))
|
||||
t.Logf(" ✓ Found session ID header: %s = %s", header, values[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check if session ID is in the request body
|
||||
if len(req.Body) > 0 {
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(req.Body, &body); err == nil {
|
||||
if sessionID, ok := body["session_id"].(string); ok && sessionID != "" {
|
||||
t.Logf(" ✓ Found session_id in body: %s", sessionID)
|
||||
}
|
||||
if params, ok := body["params"].(map[string]interface{}); ok {
|
||||
if sessionID, ok := params["session_id"].(string); ok && sessionID != "" {
|
||||
t.Logf(" ✓ Found session_id in params: %s", sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Analyze Claude events for session information
|
||||
t.Logf("\n=== Claude Event Analysis ===")
|
||||
t.Logf("Total events captured: %d", len(allEvents))
|
||||
|
||||
var claudeSessionID string
|
||||
for _, event := range allEvents {
|
||||
if event.SessionID != "" && claudeSessionID == "" {
|
||||
claudeSessionID = event.SessionID
|
||||
t.Logf("Claude session ID from events: %s", claudeSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// Final verdict
|
||||
t.Logf("\n=== VERDICT ===")
|
||||
if sessionIDFound {
|
||||
t.Logf("✓ Session ID IS sent in MCP request headers")
|
||||
t.Logf(" Headers found: %s", strings.Join(sessionIDHeaders, ", "))
|
||||
t.Logf(" Current implementation should work correctly")
|
||||
} else {
|
||||
t.Logf("✗ Session ID is NOT sent in MCP request headers")
|
||||
t.Logf(" Claude session ID: %s", claudeSessionID)
|
||||
t.Logf(" Need to implement alternative correlation mechanism")
|
||||
t.Logf(" Possible solutions:")
|
||||
t.Logf(" 1. Use MCP session initialization to establish mapping")
|
||||
t.Logf(" 2. Pass session ID in MCP server URL path")
|
||||
t.Logf(" 3. Use a unique MCP server per session")
|
||||
}
|
||||
|
||||
// Assert findings
|
||||
if !sessionIDFound {
|
||||
t.Error("Session ID is not being sent in MCP request headers - implementation needs revision")
|
||||
|
||||
// Provide detailed recommendations
|
||||
t.Log("\nRECOMMENDED CHANGES:")
|
||||
t.Log("1. Remove reliance on Session-ID header in MCP server")
|
||||
t.Log("2. Consider embedding session ID in MCP server URL:")
|
||||
t.Log(" - Change URL to: /api/v1/mcp/{session_id}")
|
||||
t.Log(" - Extract session ID from URL path in handler")
|
||||
t.Log("3. Or use MCP session correlation:")
|
||||
t.Log(" - Track MCP session ID from initialize method")
|
||||
t.Log(" - Map MCP session to HumanLayer session")
|
||||
}
|
||||
|
||||
// Additional diagnostics
|
||||
t.Logf("\n=== Additional Diagnostics ===")
|
||||
if len(requests) > 0 {
|
||||
t.Log("First request headers:")
|
||||
for k, v := range requests[0].Headers {
|
||||
t.Logf(" %s: %s", k, strings.Join(v, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
404
hld/daemon/mcp_phase4_integration_test.go
Normal file
404
hld/daemon/mcp_phase4_integration_test.go
Normal file
@@ -0,0 +1,404 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/daemon"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMCPPhase4ApprovalCreation(t *testing.T) {
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-phase4")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-phase4")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session
|
||||
sessionID := "test-session-phase4"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-phase4', 'claude-phase4', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("ApprovalCreatedWithToolUseID", func(t *testing.T) {
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_use_phase4_123"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
// Send request in background (it will block waiting for approval)
|
||||
go func() {
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
client.Do(httpReq)
|
||||
}()
|
||||
|
||||
// Wait for approval to be created
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check database for approval with tool_use_id
|
||||
var count int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Expected exactly one approval with tool_use_id")
|
||||
|
||||
// Verify approval details
|
||||
var approvalID, toolName, status string
|
||||
var toolInput string
|
||||
err = db.QueryRow(`
|
||||
SELECT id, tool_name, tool_input, status
|
||||
FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&approvalID, &toolName, &toolInput, &status)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test_tool", toolName)
|
||||
assert.Equal(t, "pending", status)
|
||||
assert.Contains(t, toolInput, `"test":"data"`)
|
||||
assert.NotEmpty(t, approvalID)
|
||||
|
||||
t.Logf("Successfully created approval with ID=%s, tool_use_id=%s", approvalID, toolUseID)
|
||||
})
|
||||
|
||||
t.Run("AutoApprovalWithDangerousSkip", func(t *testing.T) {
|
||||
// Enable dangerous skip permissions
|
||||
_, err := db.Exec(`
|
||||
UPDATE sessions
|
||||
SET dangerously_skip_permissions = 1
|
||||
WHERE id = ?
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_use_auto_approve"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "edit_tool",
|
||||
"input": map[string]interface{}{"file": "test.txt"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get immediate response due to auto-approval
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
||||
// Check that the approval was created and auto-approved
|
||||
var status, comment string
|
||||
err = db.QueryRow(`
|
||||
SELECT status, comment
|
||||
FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&status, &comment)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "approved", status)
|
||||
assert.Contains(t, comment, "dangerous skip permissions")
|
||||
|
||||
// Verify response contains allow behavior
|
||||
if responseContent, ok := result["result"].(map[string]interface{}); ok {
|
||||
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
|
||||
if textContent, ok := content[0].(map[string]interface{}); ok {
|
||||
if text, ok := textContent["text"].(string); ok {
|
||||
var responseData map[string]interface{}
|
||||
json.Unmarshal([]byte(text), &responseData)
|
||||
assert.Equal(t, "allow", responseData["behavior"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disable dangerous skip for cleanup
|
||||
_, err = db.Exec(`
|
||||
UPDATE sessions
|
||||
SET dangerously_skip_permissions = 0
|
||||
WHERE id = ?
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("MultipleApprovalsDifferentToolUseIDs", func(t *testing.T) {
|
||||
// Create multiple approval requests with different tool_use_ids
|
||||
toolUseIDs := []string{"multi_1", "multi_2", "multi_3"}
|
||||
|
||||
for _, toolUseID := range toolUseIDs {
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": toolUseID,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "multi_tool",
|
||||
"input": map[string]interface{}{"id": toolUseID},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
// Send requests in background
|
||||
go func() {
|
||||
client := &http.Client{Timeout: 1 * time.Second}
|
||||
client.Do(httpReq)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for approvals to be created
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Verify all approvals were created with correct tool_use_ids
|
||||
for _, toolUseID := range toolUseIDs {
|
||||
var count int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Expected approval for tool_use_id=%s", toolUseID)
|
||||
}
|
||||
|
||||
// Verify total count
|
||||
var totalCount int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id IN ('multi_1', 'multi_2', 'multi_3')
|
||||
AND session_id = ?
|
||||
`, sessionID).Scan(&totalCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, totalCount, "Expected exactly 3 approvals")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPPhase4AutoDenyMode(t *testing.T) {
|
||||
// Set auto-deny mode
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true")
|
||||
defer os.Unsetenv("MCP_AUTO_DENY_ALL")
|
||||
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-phase4-autodeny")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-phase4-autodeny")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session
|
||||
sessionID := "test-session-autodeny"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-autodeny', 'claude-autodeny', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("AutoDenyDoesNotCreateApproval", func(t *testing.T) {
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_autodeny"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get immediate deny response
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
||||
// Verify deny response
|
||||
if responseContent, ok := result["result"].(map[string]interface{}); ok {
|
||||
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
|
||||
if textContent, ok := content[0].(map[string]interface{}); ok {
|
||||
if text, ok := textContent["text"].(string); ok {
|
||||
var responseData map[string]interface{}
|
||||
json.Unmarshal([]byte(text), &responseData)
|
||||
assert.Equal(t, "deny", responseData["behavior"])
|
||||
assert.Contains(t, responseData["message"], "Auto-denied")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no approval was created in database
|
||||
var count int
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ?
|
||||
`, toolUseID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count, "No approval should be created in auto-deny mode")
|
||||
})
|
||||
}
|
||||
282
hld/daemon/mcp_session_header_test.go
Normal file
282
hld/daemon/mcp_session_header_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// HeaderCaptureServer captures all HTTP request headers for analysis
|
||||
type HeaderCaptureServer struct {
|
||||
mu sync.Mutex
|
||||
requests []RequestCapture
|
||||
}
|
||||
|
||||
type RequestCapture struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string][]string
|
||||
Body []byte
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
func (s *HeaderCaptureServer) Handler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture request
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
capture := RequestCapture{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Headers: r.Header.Clone(),
|
||||
Body: body,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.requests = append(s.requests, capture)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Log headers
|
||||
fmt.Printf("\n[%s %s]\n", r.Method, r.URL.Path)
|
||||
fmt.Println("Headers:")
|
||||
for k, v := range r.Header {
|
||||
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
|
||||
}
|
||||
|
||||
// Parse JSON-RPC request
|
||||
var req map[string]interface{}
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
|
||||
// Create response
|
||||
response := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
}
|
||||
|
||||
// Handle MCP methods
|
||||
switch method {
|
||||
case "initialize":
|
||||
response["result"] = map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": map[string]interface{}{
|
||||
"name": "test-server",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"capabilities": map[string]interface{}{
|
||||
"tools": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
case "tools/list":
|
||||
response["result"] = map[string]interface{}{
|
||||
"tools": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"inputSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case "tools/call":
|
||||
response["result"] = map[string]interface{}{
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "Test response",
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
response["error"] = map[string]interface{}{
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeaderCaptureServer) GetRequests() []RequestCapture {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]RequestCapture{}, s.requests...)
|
||||
}
|
||||
|
||||
func TestMCPHeaderTransmission(t *testing.T) {
|
||||
// Skip if Claude CLI is not available
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("Claude CLI not available")
|
||||
}
|
||||
|
||||
// Create header capture server
|
||||
captureServer := &HeaderCaptureServer{}
|
||||
|
||||
// Start HTTP server
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.New()
|
||||
router.Any("/mcp", gin.WrapF(captureServer.Handler()))
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
|
||||
server := &http.Server{Handler: router}
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(context.Background())
|
||||
|
||||
// Create MCP config file with custom headers
|
||||
tempDir := t.TempDir()
|
||||
mcpConfigPath := fmt.Sprintf("%s/mcp-config.json", tempDir)
|
||||
|
||||
testSessionID := "test-session-123"
|
||||
mcpConfig := map[string]interface{}{
|
||||
"mcpServers": map[string]interface{}{
|
||||
"test": map[string]interface{}{
|
||||
"type": "http",
|
||||
"url": fmt.Sprintf("%s/mcp", baseURL),
|
||||
"headers": map[string]string{
|
||||
"X-Session-ID": testSessionID,
|
||||
"X-Custom-Header": "custom-value",
|
||||
"Authorization": "Bearer test-token",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
configBytes, _ := json.MarshalIndent(mcpConfig, "", " ")
|
||||
err = os.WriteFile(mcpConfigPath, configBytes, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("MCP Config:\n%s", string(configBytes))
|
||||
|
||||
// Launch Claude with MCP config
|
||||
cmd := exec.Command("claude",
|
||||
"--print", "test",
|
||||
"--mcp-config", mcpConfigPath,
|
||||
"--max-turns", "1",
|
||||
"--model", "sonnet",
|
||||
"--output-format", "json",
|
||||
)
|
||||
cmd.Dir = tempDir
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
t.Logf("Claude output:\n%s", string(output))
|
||||
|
||||
// Allow some time for any async operations
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Analyze captured requests
|
||||
requests := captureServer.GetRequests()
|
||||
t.Logf("\n=== Captured %d MCP Requests ===", len(requests))
|
||||
|
||||
sessionIDFound := false
|
||||
customHeaderFound := false
|
||||
authHeaderFound := false
|
||||
|
||||
for i, req := range requests {
|
||||
t.Logf("\nRequest #%d: %s %s", i+1, req.Method, req.Path)
|
||||
|
||||
// Check for our custom headers
|
||||
if sessionIDs, ok := req.Headers["X-Session-Id"]; ok && len(sessionIDs) > 0 {
|
||||
sessionIDFound = true
|
||||
sessionID := sessionIDs[0]
|
||||
t.Logf(" ✓ X-Session-ID: %s", sessionID)
|
||||
if sessionID != testSessionID {
|
||||
t.Errorf(" ✗ Session ID mismatch: got %s, want %s", sessionID, testSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
if customs, ok := req.Headers["X-Custom-Header"]; ok && len(customs) > 0 {
|
||||
customHeaderFound = true
|
||||
t.Logf(" ✓ X-Custom-Header: %s", customs[0])
|
||||
}
|
||||
|
||||
if auths, ok := req.Headers["Authorization"]; ok && len(auths) > 0 {
|
||||
authHeaderFound = true
|
||||
t.Logf(" ✓ Authorization: %s", auths[0])
|
||||
}
|
||||
|
||||
// Log all headers for debugging
|
||||
t.Log(" All headers:")
|
||||
for k, v := range req.Headers {
|
||||
t.Logf(" %s: %s", k, strings.Join(v, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Verdict
|
||||
t.Log("\n=== VERDICT ===")
|
||||
if sessionIDFound && customHeaderFound && authHeaderFound {
|
||||
t.Log("✓ All custom headers are transmitted correctly")
|
||||
t.Log("✓ Session ID correlation via headers WILL work")
|
||||
} else {
|
||||
t.Log("✗ Custom headers are NOT being transmitted")
|
||||
if !sessionIDFound {
|
||||
t.Error("X-Session-ID header not found in MCP requests")
|
||||
}
|
||||
if !customHeaderFound {
|
||||
t.Error("X-Custom-Header not found in MCP requests")
|
||||
}
|
||||
if !authHeaderFound {
|
||||
t.Error("Authorization header not found in MCP requests")
|
||||
}
|
||||
|
||||
t.Log("\n=== RECOMMENDATIONS ===")
|
||||
t.Log("1. Embed session ID in the MCP server URL path")
|
||||
t.Log("2. Use unique MCP server instances per session")
|
||||
t.Log("3. Implement session correlation via MCP protocol messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPSessionCorrelationAlternatives(t *testing.T) {
|
||||
t.Log("\n=== Alternative Session Correlation Methods ===")
|
||||
|
||||
t.Log("\n1. URL Path Embedding:")
|
||||
t.Log(" - Change MCP endpoint to: /api/v1/mcp/:session_id")
|
||||
t.Log(" - Extract session ID from URL path in handler")
|
||||
t.Log(" - Pro: Simple, reliable, no header dependency")
|
||||
t.Log(" - Con: Requires URL generation per session")
|
||||
|
||||
t.Log("\n2. MCP Protocol Session:")
|
||||
t.Log(" - Use MCP's initialize response to establish session")
|
||||
t.Log(" - Store mapping: mcp_session_id -> humanlayer_session_id")
|
||||
t.Log(" - Pro: Protocol-native solution")
|
||||
t.Log(" - Con: Requires stateful MCP server")
|
||||
|
||||
t.Log("\n3. Token-based Correlation:")
|
||||
t.Log(" - Generate unique token per session")
|
||||
t.Log(" - Pass token in MCP server name or URL")
|
||||
t.Log(" - Pro: Secure, unique per session")
|
||||
t.Log(" - Con: Token management complexity")
|
||||
|
||||
t.Log("\n4. Process-based Correlation:")
|
||||
t.Log(" - Track Claude process ID")
|
||||
t.Log(" - Map process to session at launch")
|
||||
t.Log(" - Pro: OS-level tracking")
|
||||
t.Log(" - Con: Complex, platform-specific")
|
||||
}
|
||||
Reference in New Issue
Block a user