test(mcp): Add integration tests for MCP session correlation and header transmission

- Test Claude's transmission of custom headers including X-Session-ID
- Verify approval creation with tool_use_id in Phase 4 implementation
- Test auto-approval behavior with dangerously_skip_permissions flag
- Confirm session correlation works correctly via HTTP headers
This commit is contained in:
dexhorthy
2025-08-13 19:53:23 -07:00
parent c0b0335da2
commit d4d7c39fef
3 changed files with 1129 additions and 0 deletions

View File

@@ -0,0 +1,443 @@
//go:build integration
package daemon_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
// MCPRequestLog captures all details of an MCP request
type MCPRequestLog struct {
Method string
Path string
Headers map[string][]string
Body json.RawMessage
Timestamp time.Time
RequestID int
}
// MCPTestServer wraps the real MCP server and logs all requests
type MCPTestServer struct {
realHandler http.Handler
requests []MCPRequestLog
requestsMutex sync.Mutex
requestCount int
}
func (s *MCPTestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Capture request details
s.requestsMutex.Lock()
s.requestCount++
requestID := s.requestCount
// Read body for logging
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for real handler
// Log all headers
headersCopy := make(map[string][]string)
for k, v := range r.Header {
headersCopy[k] = v
}
log := MCPRequestLog{
Method: r.Method,
Path: r.URL.Path,
Headers: headersCopy,
Body: bodyBytes,
Timestamp: time.Now(),
RequestID: requestID,
}
s.requests = append(s.requests, log)
s.requestsMutex.Unlock()
// Log request details to test output
fmt.Printf("\n[MCP Request #%d] %s %s\n", requestID, r.Method, r.URL.Path)
fmt.Printf("Headers:\n")
for k, v := range r.Header {
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
}
if len(bodyBytes) > 0 {
fmt.Printf("Body: %s\n", string(bodyBytes))
}
fmt.Printf("---\n")
// Forward to real handler
s.realHandler.ServeHTTP(w, r)
}
func (s *MCPTestServer) GetRequests() []MCPRequestLog {
s.requestsMutex.Lock()
defer s.requestsMutex.Unlock()
return append([]MCPRequestLog{}, s.requests...)
}
func TestMCPClaudeCodeSessionIDCorrelation(t *testing.T) {
// Skip if Claude is not available
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("Claude CLI not available, skipping integration test")
}
// Setup isolated environment
socketPath := testutil.SocketPath(t, "mcp-claudecode")
dbPath := testutil.DatabasePath(t, "mcp-claudecode")
// Get a free port for HTTP server
httpPort := getFreePort(t)
// Override environment
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
os.Setenv("MCP_AUTO_DENY_ALL", "true") // Auto-deny for predictable responses
// Create isolated config
tempDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", tempDir)
configDir := filepath.Join(tempDir, "humanlayer")
require.NoError(t, os.MkdirAll(configDir, 0755))
configFile := filepath.Join(configDir, "humanlayer.json")
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
// Create MCP test server wrapper
mcpTestServer := &MCPTestServer{
requests: []MCPRequestLog{},
}
// Custom HTTP server setup to wrap MCP handler
gin.SetMode(gin.ReleaseMode)
router := gin.New()
// Add health endpoint
router.GET("/api/v1/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})
// Wrap MCP endpoint with test server
router.Any("/api/v1/mcp", func(c *gin.Context) {
// First time setup - get real handler from daemon
if mcpTestServer.realHandler == nil {
// Get the real MCP handler from daemon
// We'll create a simple MCP handler that auto-denies
mcpTestServer.realHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simple MCP response for testing
var req map[string]interface{}
json.NewDecoder(r.Body).Decode(&req)
method, _ := req["method"].(string)
id := req["id"]
response := map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
}
switch method {
case "initialize":
response["result"] = map[string]interface{}{
"protocolVersion": "2025-03-26",
"serverInfo": map[string]interface{}{
"name": "test-mcp-server",
"version": "1.0.0",
},
"capabilities": map[string]interface{}{
"tools": map[string]interface{}{},
},
}
case "tools/list":
response["result"] = map[string]interface{}{
"tools": []interface{}{
map[string]interface{}{
"name": "request_approval",
"description": "Request permission to execute a tool",
"inputSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"tool_name": map[string]string{"type": "string"},
"input": map[string]string{"type": "object"},
"tool_use_id": map[string]string{"type": "string"},
},
"required": []string{"tool_name", "input", "tool_use_id"},
},
},
},
}
case "tools/call":
// Auto-deny
response["result"] = map[string]interface{}{
"content": []interface{}{
map[string]interface{}{
"type": "text",
"text": `{"behavior": "deny", "message": "Auto-denied for testing"}`,
},
},
}
default:
response["error"] = map[string]interface{}{
"code": -32601,
"message": "Method not found",
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
})
}
mcpTestServer.ServeHTTP(c.Writer, c.Request)
})
// Start HTTP server
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", httpPort))
require.NoError(t, err)
server := &http.Server{
Handler: router,
}
go func() {
server.Serve(listener)
}()
defer server.Shutdown(context.Background())
// Wait for HTTP server to be ready
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
require.Eventually(t, func() bool {
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
if err == nil {
resp.Body.Close()
return resp.StatusCode == 200
}
return false
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
// Open database connection
db, err := sql.Open("sqlite3", dbPath)
require.NoError(t, err)
defer db.Close()
// Create a test session in the database
testSessionID := "test-claudecode-session"
_, err = db.Exec(`
INSERT INTO sessions (
id, run_id, claude_session_id, query, model, working_dir,
status, created_at, last_activity_at, auto_accept_edits,
dangerously_skip_permissions, max_turns, system_prompt,
custom_instructions, cost_usd, input_tokens, output_tokens,
duration_ms, num_turns, result_content, error_message
) VALUES (
?, 'run-claudecode', 'claude-test', 'test query', 'claude-3-sonnet', '/tmp',
'running', datetime('now'), datetime('now'), 0, 1, 10, '',
'', 0.0, 0, 0, 0, 0, '', ''
)
`, testSessionID)
require.NoError(t, err)
// Create claudecode client
client, err := claudecode.NewClient()
require.NoError(t, err)
// Prepare MCP configuration
// The claudecode client will write this to a temp file and pass it to claude
// We need to match the format claude expects for HTTP MCP servers
mcpConfig := &claudecode.MCPConfig{
MCPServers: map[string]claudecode.MCPServer{
"humanlayer": {
Command: "http", // This is just a placeholder
Env: map[string]string{
// The actual config will be written as JSON with type/url/headers
"_config": fmt.Sprintf(`{"type":"http","url":"%s/api/v1/mcp","headers":{"X-Session-ID":"%s"}}`, baseURL, testSessionID),
},
},
},
}
// Create session config
sessionConfig := claudecode.SessionConfig{
Query: "Say 'test complete' and exit",
Model: claudecode.ModelSonnet,
OutputFormat: claudecode.OutputStreamJSON,
MCPConfig: mcpConfig,
PermissionPromptTool: "mcp__humanlayer__request_approval",
MaxTurns: 1,
WorkingDir: tempDir,
Verbose: true,
}
// Capture events from Claude
var allEvents []claudecode.StreamEvent
var eventsMutex sync.Mutex
// Launch Claude session
t.Log("Launching Claude session with MCP config...")
session, err := client.Launch(sessionConfig)
require.NoError(t, err)
// Capture events in background
eventsDone := make(chan struct{})
go func() {
defer close(eventsDone)
for event := range session.Events {
eventsMutex.Lock()
allEvents = append(allEvents, event)
eventsMutex.Unlock()
// Log significant events
switch event.Type {
case "system":
if event.Subtype == "init" {
t.Logf("Claude session initialized: ID=%s, Model=%s", event.SessionID, event.Model)
}
case "mcp_servers":
for _, server := range event.MCPServers {
t.Logf("MCP Server %s: %s", server.Name, server.Status)
}
case "result":
t.Logf("Session completed: ID=%s, Error=%v", event.SessionID, event.IsError)
}
}
}()
// Wait for session to complete (with timeout)
done := make(chan struct{})
go func() {
defer close(done)
result, err := session.Wait()
if err != nil {
t.Logf("Session error: %v", err)
} else if result != nil {
t.Logf("Session result: %s", result.Result)
}
}()
select {
case <-done:
// Session completed
case <-time.After(30 * time.Second):
t.Log("Session timeout, interrupting...")
session.Interrupt()
<-done
}
// Wait for events to be processed
<-eventsDone
// Analyze captured MCP requests
requests := mcpTestServer.GetRequests()
t.Logf("\n=== MCP Request Analysis ===")
t.Logf("Total MCP requests: %d", len(requests))
// Check for session ID in headers
sessionIDFound := false
var sessionIDHeaders []string
for i, req := range requests {
t.Logf("\nRequest #%d: %s", i+1, req.Method)
// Check various possible session ID headers
possibleHeaders := []string{
"X-Session-ID",
"X-Session-Id",
"Session-ID",
"Session-Id",
"Mcp-Session-Id",
"MCP-Session-ID",
}
for _, header := range possibleHeaders {
if values, ok := req.Headers[header]; ok && len(values) > 0 {
sessionIDFound = true
sessionIDHeaders = append(sessionIDHeaders, fmt.Sprintf("%s: %s", header, values[0]))
t.Logf(" ✓ Found session ID header: %s = %s", header, values[0])
}
}
// Check if session ID is in the request body
if len(req.Body) > 0 {
var body map[string]interface{}
if err := json.Unmarshal(req.Body, &body); err == nil {
if sessionID, ok := body["session_id"].(string); ok && sessionID != "" {
t.Logf(" ✓ Found session_id in body: %s", sessionID)
}
if params, ok := body["params"].(map[string]interface{}); ok {
if sessionID, ok := params["session_id"].(string); ok && sessionID != "" {
t.Logf(" ✓ Found session_id in params: %s", sessionID)
}
}
}
}
}
// Analyze Claude events for session information
t.Logf("\n=== Claude Event Analysis ===")
t.Logf("Total events captured: %d", len(allEvents))
var claudeSessionID string
for _, event := range allEvents {
if event.SessionID != "" && claudeSessionID == "" {
claudeSessionID = event.SessionID
t.Logf("Claude session ID from events: %s", claudeSessionID)
}
}
// Final verdict
t.Logf("\n=== VERDICT ===")
if sessionIDFound {
t.Logf("✓ Session ID IS sent in MCP request headers")
t.Logf(" Headers found: %s", strings.Join(sessionIDHeaders, ", "))
t.Logf(" Current implementation should work correctly")
} else {
t.Logf("✗ Session ID is NOT sent in MCP request headers")
t.Logf(" Claude session ID: %s", claudeSessionID)
t.Logf(" Need to implement alternative correlation mechanism")
t.Logf(" Possible solutions:")
t.Logf(" 1. Use MCP session initialization to establish mapping")
t.Logf(" 2. Pass session ID in MCP server URL path")
t.Logf(" 3. Use a unique MCP server per session")
}
// Assert findings
if !sessionIDFound {
t.Error("Session ID is not being sent in MCP request headers - implementation needs revision")
// Provide detailed recommendations
t.Log("\nRECOMMENDED CHANGES:")
t.Log("1. Remove reliance on Session-ID header in MCP server")
t.Log("2. Consider embedding session ID in MCP server URL:")
t.Log(" - Change URL to: /api/v1/mcp/{session_id}")
t.Log(" - Extract session ID from URL path in handler")
t.Log("3. Or use MCP session correlation:")
t.Log(" - Track MCP session ID from initialize method")
t.Log(" - Map MCP session to HumanLayer session")
}
// Additional diagnostics
t.Logf("\n=== Additional Diagnostics ===")
if len(requests) > 0 {
t.Log("First request headers:")
for k, v := range requests[0].Headers {
t.Logf(" %s: %s", k, strings.Join(v, ", "))
}
}
}

View File

@@ -0,0 +1,404 @@
//go:build integration
package daemon_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/humanlayer/humanlayer/hld/daemon"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMCPPhase4ApprovalCreation(t *testing.T) {
// Setup isolated environment
socketPath := testutil.SocketPath(t, "mcp-phase4")
dbPath := testutil.DatabasePath(t, "mcp-phase4")
// Get a free port for HTTP server
httpPort := getFreePort(t)
// Override environment
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
// Create isolated config
tempDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", tempDir)
configDir := filepath.Join(tempDir, "humanlayer")
require.NoError(t, os.MkdirAll(configDir, 0755))
configFile := filepath.Join(configDir, "humanlayer.json")
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
// Create daemon
d, err := daemon.New()
require.NoError(t, err, "Failed to create daemon")
// Start daemon in background
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- d.Run(ctx)
}()
// Wait for HTTP server to be ready
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
require.Eventually(t, func() bool {
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
if err == nil {
resp.Body.Close()
return resp.StatusCode == 200
}
return false
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
// Open database connection
db, err := sql.Open("sqlite3", dbPath)
require.NoError(t, err)
defer db.Close()
// Create a test session
sessionID := "test-session-phase4"
_, err = db.Exec(`
INSERT INTO sessions (
id, run_id, claude_session_id, query, model, working_dir,
status, created_at, last_activity_at, auto_accept_edits,
dangerously_skip_permissions, max_turns, system_prompt,
custom_instructions, cost_usd, input_tokens, output_tokens,
duration_ms, num_turns, result_content, error_message
) VALUES (
?, 'run-phase4', 'claude-phase4', 'test query', 'claude-3-sonnet', '/tmp',
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
'', 0.0, 0, 0, 0, 0, '', ''
)
`, sessionID)
require.NoError(t, err)
t.Run("ApprovalCreatedWithToolUseID", func(t *testing.T) {
// Send MCP approval request
toolUseID := "test_use_phase4_123"
req := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]interface{}{
"name": "request_approval",
"arguments": map[string]interface{}{
"tool_name": "test_tool",
"input": map[string]interface{}{"test": "data"},
"tool_use_id": toolUseID,
},
},
}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("X-Session-ID", sessionID)
// Send request in background (it will block waiting for approval)
go func() {
client := &http.Client{Timeout: 2 * time.Second}
client.Do(httpReq)
}()
// Wait for approval to be created
time.Sleep(500 * time.Millisecond)
// Check database for approval with tool_use_id
var count int
err := db.QueryRow(`
SELECT COUNT(*) FROM approvals
WHERE tool_use_id = ? AND session_id = ?
`, toolUseID, sessionID).Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count, "Expected exactly one approval with tool_use_id")
// Verify approval details
var approvalID, toolName, status string
var toolInput string
err = db.QueryRow(`
SELECT id, tool_name, tool_input, status
FROM approvals
WHERE tool_use_id = ? AND session_id = ?
`, toolUseID, sessionID).Scan(&approvalID, &toolName, &toolInput, &status)
require.NoError(t, err)
assert.Equal(t, "test_tool", toolName)
assert.Equal(t, "pending", status)
assert.Contains(t, toolInput, `"test":"data"`)
assert.NotEmpty(t, approvalID)
t.Logf("Successfully created approval with ID=%s, tool_use_id=%s", approvalID, toolUseID)
})
t.Run("AutoApprovalWithDangerousSkip", func(t *testing.T) {
// Enable dangerous skip permissions
_, err := db.Exec(`
UPDATE sessions
SET dangerously_skip_permissions = 1
WHERE id = ?
`, sessionID)
require.NoError(t, err)
// Send MCP approval request
toolUseID := "test_use_auto_approve"
req := map[string]interface{}{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": map[string]interface{}{
"name": "request_approval",
"arguments": map[string]interface{}{
"tool_name": "edit_tool",
"input": map[string]interface{}{"file": "test.txt"},
"tool_use_id": toolUseID,
},
},
}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("X-Session-ID", sessionID)
resp, err := http.DefaultClient.Do(httpReq)
require.NoError(t, err)
defer resp.Body.Close()
// Should get immediate response due to auto-approval
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
// Check that the approval was created and auto-approved
var status, comment string
err = db.QueryRow(`
SELECT status, comment
FROM approvals
WHERE tool_use_id = ? AND session_id = ?
`, toolUseID, sessionID).Scan(&status, &comment)
require.NoError(t, err)
assert.Equal(t, "approved", status)
assert.Contains(t, comment, "dangerous skip permissions")
// Verify response contains allow behavior
if responseContent, ok := result["result"].(map[string]interface{}); ok {
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
if textContent, ok := content[0].(map[string]interface{}); ok {
if text, ok := textContent["text"].(string); ok {
var responseData map[string]interface{}
json.Unmarshal([]byte(text), &responseData)
assert.Equal(t, "allow", responseData["behavior"])
}
}
}
}
// Disable dangerous skip for cleanup
_, err = db.Exec(`
UPDATE sessions
SET dangerously_skip_permissions = 0
WHERE id = ?
`, sessionID)
require.NoError(t, err)
})
t.Run("MultipleApprovalsDifferentToolUseIDs", func(t *testing.T) {
// Create multiple approval requests with different tool_use_ids
toolUseIDs := []string{"multi_1", "multi_2", "multi_3"}
for _, toolUseID := range toolUseIDs {
req := map[string]interface{}{
"jsonrpc": "2.0",
"id": toolUseID,
"method": "tools/call",
"params": map[string]interface{}{
"name": "request_approval",
"arguments": map[string]interface{}{
"tool_name": "multi_tool",
"input": map[string]interface{}{"id": toolUseID},
"tool_use_id": toolUseID,
},
},
}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("X-Session-ID", sessionID)
// Send requests in background
go func() {
client := &http.Client{Timeout: 1 * time.Second}
client.Do(httpReq)
}()
}
// Wait for approvals to be created
time.Sleep(500 * time.Millisecond)
// Verify all approvals were created with correct tool_use_ids
for _, toolUseID := range toolUseIDs {
var count int
err := db.QueryRow(`
SELECT COUNT(*) FROM approvals
WHERE tool_use_id = ? AND session_id = ?
`, toolUseID, sessionID).Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count, "Expected approval for tool_use_id=%s", toolUseID)
}
// Verify total count
var totalCount int
err := db.QueryRow(`
SELECT COUNT(*) FROM approvals
WHERE tool_use_id IN ('multi_1', 'multi_2', 'multi_3')
AND session_id = ?
`, sessionID).Scan(&totalCount)
require.NoError(t, err)
assert.Equal(t, 3, totalCount, "Expected exactly 3 approvals")
})
}
func TestMCPPhase4AutoDenyMode(t *testing.T) {
// Set auto-deny mode
os.Setenv("MCP_AUTO_DENY_ALL", "true")
defer os.Unsetenv("MCP_AUTO_DENY_ALL")
// Setup isolated environment
socketPath := testutil.SocketPath(t, "mcp-phase4-autodeny")
dbPath := testutil.DatabasePath(t, "mcp-phase4-autodeny")
// Get a free port for HTTP server
httpPort := getFreePort(t)
// Override environment
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
// Create isolated config
tempDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", tempDir)
configDir := filepath.Join(tempDir, "humanlayer")
require.NoError(t, os.MkdirAll(configDir, 0755))
configFile := filepath.Join(configDir, "humanlayer.json")
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
// Create daemon
d, err := daemon.New()
require.NoError(t, err, "Failed to create daemon")
// Start daemon in background
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- d.Run(ctx)
}()
// Wait for HTTP server to be ready
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
require.Eventually(t, func() bool {
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
if err == nil {
resp.Body.Close()
return resp.StatusCode == 200
}
return false
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
// Open database connection
db, err := sql.Open("sqlite3", dbPath)
require.NoError(t, err)
defer db.Close()
// Create a test session
sessionID := "test-session-autodeny"
_, err = db.Exec(`
INSERT INTO sessions (
id, run_id, claude_session_id, query, model, working_dir,
status, created_at, last_activity_at, auto_accept_edits,
dangerously_skip_permissions, max_turns, system_prompt,
custom_instructions, cost_usd, input_tokens, output_tokens,
duration_ms, num_turns, result_content, error_message
) VALUES (
?, 'run-autodeny', 'claude-autodeny', 'test query', 'claude-3-sonnet', '/tmp',
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
'', 0.0, 0, 0, 0, 0, '', ''
)
`, sessionID)
require.NoError(t, err)
t.Run("AutoDenyDoesNotCreateApproval", func(t *testing.T) {
// Send MCP approval request
toolUseID := "test_autodeny"
req := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]interface{}{
"name": "request_approval",
"arguments": map[string]interface{}{
"tool_name": "test_tool",
"input": map[string]interface{}{"test": "data"},
"tool_use_id": toolUseID,
},
},
}
body, _ := json.Marshal(req)
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("X-Session-ID", sessionID)
resp, err := http.DefaultClient.Do(httpReq)
require.NoError(t, err)
defer resp.Body.Close()
// Should get immediate deny response
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
// Verify deny response
if responseContent, ok := result["result"].(map[string]interface{}); ok {
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
if textContent, ok := content[0].(map[string]interface{}); ok {
if text, ok := textContent["text"].(string); ok {
var responseData map[string]interface{}
json.Unmarshal([]byte(text), &responseData)
assert.Equal(t, "deny", responseData["behavior"])
assert.Contains(t, responseData["message"], "Auto-denied")
}
}
}
}
// Verify no approval was created in database
var count int
err = db.QueryRow(`
SELECT COUNT(*) FROM approvals
WHERE tool_use_id = ?
`, toolUseID).Scan(&count)
require.NoError(t, err)
assert.Equal(t, 0, count, "No approval should be created in auto-deny mode")
})
}

View File

@@ -0,0 +1,282 @@
//go:build integration
package daemon_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// HeaderCaptureServer captures all HTTP request headers for analysis
type HeaderCaptureServer struct {
mu sync.Mutex
requests []RequestCapture
}
type RequestCapture struct {
Method string
Path string
Headers map[string][]string
Body []byte
Timestamp time.Time
}
func (s *HeaderCaptureServer) Handler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Capture request
body, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewReader(body))
capture := RequestCapture{
Method: r.Method,
Path: r.URL.Path,
Headers: r.Header.Clone(),
Body: body,
Timestamp: time.Now(),
}
s.mu.Lock()
s.requests = append(s.requests, capture)
s.mu.Unlock()
// Log headers
fmt.Printf("\n[%s %s]\n", r.Method, r.URL.Path)
fmt.Println("Headers:")
for k, v := range r.Header {
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
}
// Parse JSON-RPC request
var req map[string]interface{}
json.Unmarshal(body, &req)
method, _ := req["method"].(string)
id := req["id"]
// Create response
response := map[string]interface{}{
"jsonrpc": "2.0",
"id": id,
}
// Handle MCP methods
switch method {
case "initialize":
response["result"] = map[string]interface{}{
"protocolVersion": "2025-03-26",
"serverInfo": map[string]interface{}{
"name": "test-server",
"version": "1.0.0",
},
"capabilities": map[string]interface{}{
"tools": map[string]interface{}{},
},
}
case "tools/list":
response["result"] = map[string]interface{}{
"tools": []interface{}{
map[string]interface{}{
"name": "test_tool",
"description": "A test tool",
"inputSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
},
},
},
}
case "tools/call":
response["result"] = map[string]interface{}{
"content": []interface{}{
map[string]interface{}{
"type": "text",
"text": "Test response",
},
},
}
default:
response["error"] = map[string]interface{}{
"code": -32601,
"message": "Method not found",
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
}
func (s *HeaderCaptureServer) GetRequests() []RequestCapture {
s.mu.Lock()
defer s.mu.Unlock()
return append([]RequestCapture{}, s.requests...)
}
func TestMCPHeaderTransmission(t *testing.T) {
// Skip if Claude CLI is not available
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("Claude CLI not available")
}
// Create header capture server
captureServer := &HeaderCaptureServer{}
// Start HTTP server
gin.SetMode(gin.ReleaseMode)
router := gin.New()
router.Any("/mcp", gin.WrapF(captureServer.Handler()))
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)
server := &http.Server{Handler: router}
go server.Serve(listener)
defer server.Shutdown(context.Background())
// Create MCP config file with custom headers
tempDir := t.TempDir()
mcpConfigPath := fmt.Sprintf("%s/mcp-config.json", tempDir)
testSessionID := "test-session-123"
mcpConfig := map[string]interface{}{
"mcpServers": map[string]interface{}{
"test": map[string]interface{}{
"type": "http",
"url": fmt.Sprintf("%s/mcp", baseURL),
"headers": map[string]string{
"X-Session-ID": testSessionID,
"X-Custom-Header": "custom-value",
"Authorization": "Bearer test-token",
},
},
},
}
configBytes, _ := json.MarshalIndent(mcpConfig, "", " ")
err = os.WriteFile(mcpConfigPath, configBytes, 0644)
require.NoError(t, err)
t.Logf("MCP Config:\n%s", string(configBytes))
// Launch Claude with MCP config
cmd := exec.Command("claude",
"--print", "test",
"--mcp-config", mcpConfigPath,
"--max-turns", "1",
"--model", "sonnet",
"--output-format", "json",
)
cmd.Dir = tempDir
output, err := cmd.CombinedOutput()
t.Logf("Claude output:\n%s", string(output))
// Allow some time for any async operations
time.Sleep(500 * time.Millisecond)
// Analyze captured requests
requests := captureServer.GetRequests()
t.Logf("\n=== Captured %d MCP Requests ===", len(requests))
sessionIDFound := false
customHeaderFound := false
authHeaderFound := false
for i, req := range requests {
t.Logf("\nRequest #%d: %s %s", i+1, req.Method, req.Path)
// Check for our custom headers
if sessionIDs, ok := req.Headers["X-Session-Id"]; ok && len(sessionIDs) > 0 {
sessionIDFound = true
sessionID := sessionIDs[0]
t.Logf(" ✓ X-Session-ID: %s", sessionID)
if sessionID != testSessionID {
t.Errorf(" ✗ Session ID mismatch: got %s, want %s", sessionID, testSessionID)
}
}
if customs, ok := req.Headers["X-Custom-Header"]; ok && len(customs) > 0 {
customHeaderFound = true
t.Logf(" ✓ X-Custom-Header: %s", customs[0])
}
if auths, ok := req.Headers["Authorization"]; ok && len(auths) > 0 {
authHeaderFound = true
t.Logf(" ✓ Authorization: %s", auths[0])
}
// Log all headers for debugging
t.Log(" All headers:")
for k, v := range req.Headers {
t.Logf(" %s: %s", k, strings.Join(v, ", "))
}
}
// Verdict
t.Log("\n=== VERDICT ===")
if sessionIDFound && customHeaderFound && authHeaderFound {
t.Log("✓ All custom headers are transmitted correctly")
t.Log("✓ Session ID correlation via headers WILL work")
} else {
t.Log("✗ Custom headers are NOT being transmitted")
if !sessionIDFound {
t.Error("X-Session-ID header not found in MCP requests")
}
if !customHeaderFound {
t.Error("X-Custom-Header not found in MCP requests")
}
if !authHeaderFound {
t.Error("Authorization header not found in MCP requests")
}
t.Log("\n=== RECOMMENDATIONS ===")
t.Log("1. Embed session ID in the MCP server URL path")
t.Log("2. Use unique MCP server instances per session")
t.Log("3. Implement session correlation via MCP protocol messages")
}
}
func TestMCPSessionCorrelationAlternatives(t *testing.T) {
t.Log("\n=== Alternative Session Correlation Methods ===")
t.Log("\n1. URL Path Embedding:")
t.Log(" - Change MCP endpoint to: /api/v1/mcp/:session_id")
t.Log(" - Extract session ID from URL path in handler")
t.Log(" - Pro: Simple, reliable, no header dependency")
t.Log(" - Con: Requires URL generation per session")
t.Log("\n2. MCP Protocol Session:")
t.Log(" - Use MCP's initialize response to establish session")
t.Log(" - Store mapping: mcp_session_id -> humanlayer_session_id")
t.Log(" - Pro: Protocol-native solution")
t.Log(" - Con: Requires stateful MCP server")
t.Log("\n3. Token-based Correlation:")
t.Log(" - Generate unique token per session")
t.Log(" - Pass token in MCP server name or URL")
t.Log(" - Pro: Secure, unique per session")
t.Log(" - Con: Token management complexity")
t.Log("\n4. Process-based Correlation:")
t.Log(" - Track Claude process ID")
t.Log(" - Map process to session at launch")
t.Log(" - Pro: OS-level tracking")
t.Log(" - Con: Complex, platform-specific")
}