diff --git a/.claude/commands/local_review.md b/.claude/commands/local_review.md index e585aef..d74ab93 100644 --- a/.claude/commands/local_review.md +++ b/.claude/commands/local_review.md @@ -26,12 +26,7 @@ When invoked with a parameter like `gh_username:branchName`: - Run setup: `make -C WORKTREE setup` - Initialize thoughts: `cd WORKTREE && npx humanlayer thoughts init --directory humanlayer` -5. **Build CodeLayer in background**: - - Run: `make -C WORKTREE codelayer-dev &` - - This builds CodeLayer in the background so it's ready when needed - - If port 1420 is already in use, run with a different port: `VITE_PORT=1421 make -C WORKTREE codelayer-dev &` - -6. **Launch Claude Code session**: +5. **Launch Claude Code session**: - If a ticket number was found: `npx humanlayer launch --model opus -w WORKTREE "We're working on ENG-XXXX - fetch the issue and then await further instructions"` - Otherwise: `npx humanlayer launch --model opus -w WORKTREE "We're reviewing the branch BRANCHNAME - please familiarize yourself with the changes and await further instructions"` @@ -40,7 +35,6 @@ When invoked with a parameter like `gh_username:branchName`: - If worktree already exists, inform the user they need to remove it first - If remote fetch fails, check if the username/repo exists - If setup fails, provide the error but continue with the launch -- If CodeLayer fails with "Port 1420 is already in use", use the VITE_PORT alternative command ## Example Usage @@ -52,5 +46,4 @@ This will: - Add 'samdickson22' as a remote - Create worktree at `~/wt/humanlayer/eng-1696` - Set up the environment -- Build CodeLayer in the background - Launch Claude with: "We're working on ENG-1696 - fetch the issue and then await further instructions" diff --git a/claudecode-go/types.go b/claudecode-go/types.go index b4f9617..b06c5d1 100644 --- a/claudecode-go/types.go +++ b/claudecode-go/types.go @@ -27,17 +27,10 @@ const ( ) // MCPServer represents a single MCP server configuration -// It can be either a stdio-based server (with command/args/env) or an HTTP server (with type/url/headers) type MCPServer struct { - // For stdio-based servers - Command string `json:"command,omitempty"` + Command string `json:"command"` Args []string `json:"args,omitempty"` Env map[string]string `json:"env,omitempty"` - - // For HTTP servers - Type string `json:"type,omitempty"` // "http" for HTTP servers - URL string `json:"url,omitempty"` // The HTTP endpoint URL - Headers map[string]string `json:"headers,omitempty"` // HTTP headers to include } // MCPConfig represents the MCP configuration structure diff --git a/hld/CLAUDE.md b/hld/CLAUDE.md index df91d96..cb31dba 100644 --- a/hld/CLAUDE.md +++ b/hld/CLAUDE.md @@ -6,7 +6,7 @@ The daemon logs are in ~/.humanlayer/logs/daemon-*.log (timestamped files create WUI logs (which include daemon stderr output) are in: - Development: `~/.humanlayer/logs/wui-{branch}/codelayer.log` -- Production: Platform-specific log directories, e.g. ~/Library/Logs/dev.humanlayer.wui.nightly/CodeLayer-Nightly.log +- Production: Platform-specific log directories It uses a database at ~/.humanlayer/*.db - you can access it with sqlite3 to inspect progress and debug things. @@ -47,9 +47,3 @@ echo '{"jsonrpc":"2.0","method":"getSessionLeaves","params":{},"id":1}' | nc -U For testing guidelines and database isolation requirements, see TESTING.md - - -### Go style guidelines - -- any async or long-running goroutine should accept a context.Context as a parameter and handle cancellation gracefully -- context and CancelFuncs should never be stored on structs, always passed as the first parameter to a function diff --git a/hld/api/handlers/sessions_test.go b/hld/api/handlers/sessions_test.go index d0c4e4f..64eaf2d 100644 --- a/hld/api/handlers/sessions_test.go +++ b/hld/api/handlers/sessions_test.go @@ -89,7 +89,7 @@ func TestSessionHandlers_CreateSession(t *testing.T) { McpConfig: &api.MCPConfig{ McpServers: &map[string]api.MCPServer{ "test-server": { - Command: stringPtr("node"), + Command: "node", Args: &[]string{"server.js"}, Env: &map[string]string{ "DEBUG": "true", diff --git a/hld/api/mapper/mapper.go b/hld/api/mapper/mapper.go index 4f42719..4957275 100644 --- a/hld/api/mapper/mapper.go +++ b/hld/api/mapper/mapper.go @@ -201,22 +201,8 @@ func (m *Mapper) MCPConfigFromAPI(config *api.MCPConfig) *claudecode.MCPConfig { servers := make(map[string]claudecode.MCPServer) if config.McpServers != nil { for name, server := range *config.McpServers { - mcpServer := claudecode.MCPServer{} - - // Map HTTP server fields - if server.Type != nil { - mcpServer.Type = *server.Type - } - if server.Url != nil { - mcpServer.URL = *server.Url - } - if server.Headers != nil { - mcpServer.Headers = *server.Headers - } - - // Map stdio server fields - if server.Command != nil { - mcpServer.Command = *server.Command + mcpServer := claudecode.MCPServer{ + Command: server.Command, } if server.Args != nil { mcpServer.Args = *server.Args @@ -224,7 +210,6 @@ func (m *Mapper) MCPConfigFromAPI(config *api.MCPConfig) *claudecode.MCPConfig { if server.Env != nil { mcpServer.Env = *server.Env } - servers[name] = mcpServer } } diff --git a/hld/api/openapi.yaml b/hld/api/openapi.yaml index 10b824a..34cca9c 100644 --- a/hld/api/openapi.yaml +++ b/hld/api/openapi.yaml @@ -1242,39 +1242,26 @@ components: MCPServer: type: object + required: + - command properties: - type: - type: string - description: Server type (http for HTTP servers, omit for stdio) - example: http command: type: string - description: Command to execute (for stdio servers) + description: Command to execute example: mcp-server-filesystem args: type: array items: type: string - description: Command arguments (for stdio servers) + description: Command arguments example: ["--read-only", "/home/user"] env: type: object additionalProperties: type: string - description: Environment variables (for stdio servers) + description: Environment variables example: DEBUG: "true" - url: - type: string - description: HTTP endpoint URL (for HTTP servers) - example: http://localhost:7777/api/v1/mcp - headers: - type: object - additionalProperties: - type: string - description: HTTP headers to include (for HTTP servers) - example: - X-Session-ID: "session-123" # Event Types EventType: diff --git a/hld/api/server.gen.go b/hld/api/server.gen.go index 6303c04..d8b46e8 100644 --- a/hld/api/server.gen.go +++ b/hld/api/server.gen.go @@ -447,23 +447,14 @@ type MCPConfig struct { // MCPServer defines model for MCPServer. type MCPServer struct { - // Args Command arguments (for stdio servers) + // Args Command arguments Args *[]string `json:"args,omitempty"` - // Command Command to execute (for stdio servers) - Command *string `json:"command,omitempty"` + // Command Command to execute + Command string `json:"command"` - // Env Environment variables (for stdio servers) + // Env Environment variables Env *map[string]string `json:"env,omitempty"` - - // Headers HTTP headers to include (for HTTP servers) - Headers *map[string]string `json:"headers,omitempty"` - - // Type Server type (http for HTTP servers, omit for stdio) - Type *string `json:"type,omitempty"` - - // Url HTTP endpoint URL (for HTTP servers) - Url *string `json:"url,omitempty"` } // RecentPath defines model for RecentPath. diff --git a/hld/approval/manager.go b/hld/approval/manager.go index 05b6784..72a8e72 100644 --- a/hld/approval/manager.go +++ b/hld/approval/manager.go @@ -232,7 +232,7 @@ func (m *manager) correlateApproval(ctx context.Context, approval *store.Approva } // Correlate by tool ID - if err := m.store.LinkConversationEventToApprovalUsingToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil { + if err := m.store.CorrelateApprovalByToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil { return fmt.Errorf("failed to correlate approval: %w", err) } @@ -258,20 +258,15 @@ func (m *manager) publishNewApprovalEvent(approval *store.Approval) { // publishApprovalResolvedEvent publishes an event when an approval is resolved func (m *manager) publishApprovalResolvedEvent(approval *store.Approval, approved bool, responseText string) { if m.eventBus != nil { - eventData := map[string]interface{}{ - "approval_id": approval.ID, - "session_id": approval.SessionID, - "approved": approved, - "response_text": responseText, - } - // Include tool_use_id if present - if approval.ToolUseID != nil { - eventData["tool_use_id"] = *approval.ToolUseID - } event := bus.Event{ Type: bus.EventApprovalResolved, Timestamp: time.Now(), - Data: eventData, + Data: map[string]interface{}{ + "approval_id": approval.ID, + "session_id": approval.SessionID, + "approved": approved, + "response_text": responseText, + }, } m.eventBus.Publish(event) } @@ -286,109 +281,6 @@ func (m *manager) updateSessionStatus(ctx context.Context, sessionID, status str return m.store.UpdateSession(ctx, sessionID, updates) } -// CreateApprovalWithToolUseID creates an approval with tool_use_id field -func (m *manager) CreateApprovalWithToolUseID(ctx context.Context, sessionID, toolName string, toolInput json.RawMessage, toolUseID string) (*store.Approval, error) { - // Check if auto-accept is enabled (either mode) - session, err := m.store.GetSession(ctx, sessionID) - if err != nil { - return nil, fmt.Errorf("failed to get session: %w", err) - } - if session == nil { - return nil, fmt.Errorf("session not found: %s", sessionID) - } - - status := store.ApprovalStatusLocalPending - comment := "" - - // Check dangerously skip permissions first (overrides edit mode) - if session.DangerouslySkipPermissions { - // Check if it has an expiry and if it's expired - if session.DangerouslySkipPermissionsExpiresAt != nil && time.Now().After(*session.DangerouslySkipPermissionsExpiresAt) { - // Expired - disable it - update := store.SessionUpdate{ - DangerouslySkipPermissions: &[]bool{false}[0], - DangerouslySkipPermissionsExpiresAt: &[]*time.Time{nil}[0], - } - if err := m.store.UpdateSession(ctx, session.ID, update); err != nil { - slog.Error("failed to disable expired dangerously skip permissions", "session_id", session.ID, "error", err) - } - // Continue with normal approval - } else { - // Dangerously skip permissions is active (no expiry or not expired) - status = store.ApprovalStatusLocalApproved - comment = "Auto-accepted (dangerous skip permissions enabled)" - } - } else if session.AutoAcceptEdits && isEditTool(toolName) { - // Regular auto-accept edits mode - status = store.ApprovalStatusLocalApproved - comment = "Auto-accepted (auto-accept mode enabled)" - } - - // Create approval with tool_use_id - approval := &store.Approval{ - ID: "local-" + uuid.New().String(), - RunID: session.RunID, - SessionID: sessionID, - ToolUseID: &toolUseID, - Status: status, - CreatedAt: time.Now(), - ToolName: toolName, - ToolInput: toolInput, - Comment: comment, - } - - // Store it - if err := m.store.CreateApproval(ctx, approval); err != nil { - return nil, fmt.Errorf("failed to store approval: %w", err) - } - - // Publish event for real-time updates - m.publishNewApprovalEvent(approval) - - if err := m.store.LinkConversationEventToApprovalUsingToolID(ctx, sessionID, toolUseID, approval.ID); err != nil { - // Log but don't fail - // TODO(1): Don't ship if above LinkConversationEventToApprovalUsingToolID does not retry - // it's possible, albeit unlikely, that the raw_event has not made it to - // conversation_events yet - return nil, fmt.Errorf("failed to correlate approval: %w", err) - } - - // Handle status-specific post-creation tasks - switch status { - case store.ApprovalStatusLocalPending: - // Update session status to waiting_input for pending approvals - if err := m.updateSessionStatus(ctx, session.ID, store.SessionStatusWaitingInput); err != nil { - slog.Warn("failed to update session status", - "error", err, - "session_id", session.ID) - } - case store.ApprovalStatusLocalApproved: - // For auto-approved, update correlation status immediately - // Update approval status - if err := m.store.UpdateApprovalStatus(ctx, approval.ID, store.ApprovalStatusApproved); err != nil { - slog.Warn("failed to update approval status in conversation events", - "error", err, - "approval_id", approval.ID) - } - // Publish resolved event for auto-approved - m.publishApprovalResolvedEvent(approval, true, comment) - } - - logLevel := slog.LevelInfo - if status == store.ApprovalStatusLocalApproved { - logLevel = slog.LevelDebug // Less noise for auto-approved - } - slog.Log(ctx, logLevel, "created approval with tool_use_id", - "approval_id", approval.ID, - "session_id", sessionID, - "tool_name", toolName, - "tool_use_id", toolUseID, - "status", status, - "auto_accepted", status == store.ApprovalStatusLocalApproved) - - return approval, nil -} - // isEditTool checks if a tool name is one of the edit tools func isEditTool(toolName string) bool { return toolName == "Edit" || toolName == "Write" || toolName == "MultiEdit" diff --git a/hld/approval/manager_test.go b/hld/approval/manager_test.go index 5692b87..f89583b 100644 --- a/hld/approval/manager_test.go +++ b/hld/approval/manager_test.go @@ -248,7 +248,7 @@ func TestManager_CorrelateApproval(t *testing.T) { mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(pendingToolCall, nil) // Mock correlating by tool ID - mockStore.EXPECT().LinkConversationEventToApprovalUsingToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil) + mockStore.EXPECT().CorrelateApprovalByToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil) // Mock event publishing mockEventBus.EXPECT().Publish(gomock.Any()) diff --git a/hld/approval/types.go b/hld/approval/types.go index 3a40cbc..e68b599 100644 --- a/hld/approval/types.go +++ b/hld/approval/types.go @@ -12,9 +12,6 @@ type Manager interface { // Create a new approval CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error) - // Create approval with tool_use_id (Phase 4) - CreateApprovalWithToolUseID(ctx context.Context, sessionID, toolName string, toolInput json.RawMessage, toolUseID string) (*store.Approval, error) - // Retrieval methods GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error) GetApproval(ctx context.Context, id string) (*store.Approval, error) diff --git a/hld/daemon/daemon_subscription_integration_test.go b/hld/daemon/daemon_subscription_integration_test.go index a2d295c..7370f58 100644 --- a/hld/daemon/daemon_subscription_integration_test.go +++ b/hld/daemon/daemon_subscription_integration_test.go @@ -71,10 +71,8 @@ func TestDaemonSubscriptionIntegration(t *testing.T) { } // Verify all clients get subscriber count - // Note: MCP server adds 1 subscriber for listening to approval events - expectedSubscribers := numClients + 1 - if subCount := daemon.eventBus.GetSubscriberCount(); subCount != expectedSubscribers { - t.Errorf("Expected %d subscribers (including MCP listener), got %d", expectedSubscribers, subCount) + if subCount := daemon.eventBus.GetSubscriberCount(); subCount != numClients { + t.Errorf("Expected %d subscribers, got %d", numClients, subCount) } // Publish an event @@ -268,8 +266,6 @@ func TestDaemonMemoryStability(t *testing.T) { socketPath := testutil.CreateTestSocket(t) t.Setenv("HUMANLAYER_SOCKET_PATH", socketPath) t.Setenv("HUMANLAYER_LOG_LEVEL", "error") - // Use in-memory database for tests - t.Setenv("HUMANLAYER_DATABASE_PATH", ":memory:") // Create and start daemon daemon, err := New() @@ -333,9 +329,9 @@ func TestDaemonMemoryStability(t *testing.T) { time.Sleep(10 * time.Millisecond) } - // Check final subscriber count (should be 1 for MCP listener) + // Check final subscriber count (should be 0) finalCount := daemon.eventBus.GetSubscriberCount() - if finalCount != 1 { - t.Errorf("Expected 1 subscriber (MCP listener) after all clients disconnected, got %d", finalCount) + if finalCount != 0 { + t.Errorf("Expected 0 subscribers after all clients disconnected, got %d", finalCount) } } diff --git a/hld/daemon/http_server.go b/hld/daemon/http_server.go index d4c707c..4f60f13 100644 --- a/hld/daemon/http_server.go +++ b/hld/daemon/http_server.go @@ -15,7 +15,6 @@ import ( "github.com/humanlayer/humanlayer/hld/approval" "github.com/humanlayer/humanlayer/hld/bus" "github.com/humanlayer/humanlayer/hld/config" - "github.com/humanlayer/humanlayer/hld/mcp" "github.com/humanlayer/humanlayer/hld/session" "github.com/humanlayer/humanlayer/hld/store" ) @@ -27,8 +26,6 @@ type HTTPServer struct { sessionHandlers *handlers.SessionHandlers approvalHandlers *handlers.ApprovalHandlers sseHandler *handlers.SSEHandler - approvalManager approval.Manager - eventBus bus.EventBus server *http.Server } @@ -74,8 +71,6 @@ func NewHTTPServer( sessionHandlers: sessionHandlers, approvalHandlers: approvalHandlers, sseHandler: sseHandler, - approvalManager: approvalManager, - eventBus: eventBus, } } @@ -96,13 +91,6 @@ func (s *HTTPServer) Start(ctx context.Context) error { // Register SSE endpoint directly (not part of strict interface) v1.GET("/stream/events", s.sseHandler.StreamEvents) - // MCP endpoint (Phase 5: with event-driven approvals) - mcpServer := mcp.NewMCPServer(s.approvalManager, s.eventBus) - mcpServer.Start(ctx) // Start background processes with context - v1.Any("/mcp", func(c *gin.Context) { - mcpServer.ServeHTTP(c.Writer, c.Request) - }) - // Create listener first to handle port 0 addr := fmt.Sprintf("%s:%d", s.config.HTTPHost, s.config.HTTPPort) listener, err := net.Listen("tcp", addr) diff --git a/hld/daemon/mcp_claudecode_integration_test.go b/hld/daemon/mcp_claudecode_integration_test.go deleted file mode 100644 index aae6227..0000000 --- a/hld/daemon/mcp_claudecode_integration_test.go +++ /dev/null @@ -1,443 +0,0 @@ -//go:build integration - -package daemon_test - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "testing" - "time" - - "github.com/gin-gonic/gin" - claudecode "github.com/humanlayer/humanlayer/claudecode-go" - "github.com/humanlayer/humanlayer/hld/internal/testutil" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/require" -) - -// MCPRequestLog captures all details of an MCP request -type MCPRequestLog struct { - Method string - Path string - Headers map[string][]string - Body json.RawMessage - Timestamp time.Time - RequestID int -} - -// MCPTestServer wraps the real MCP server and logs all requests -type MCPTestServer struct { - realHandler http.Handler - requests []MCPRequestLog - requestsMutex sync.Mutex - requestCount int -} - -func (s *MCPTestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Capture request details - s.requestsMutex.Lock() - s.requestCount++ - requestID := s.requestCount - - // Read body for logging - bodyBytes, _ := io.ReadAll(r.Body) - r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for real handler - - // Log all headers - headersCopy := make(map[string][]string) - for k, v := range r.Header { - headersCopy[k] = v - } - - log := MCPRequestLog{ - Method: r.Method, - Path: r.URL.Path, - Headers: headersCopy, - Body: bodyBytes, - Timestamp: time.Now(), - RequestID: requestID, - } - s.requests = append(s.requests, log) - s.requestsMutex.Unlock() - - // Log request details to test output - fmt.Printf("\n[MCP Request #%d] %s %s\n", requestID, r.Method, r.URL.Path) - fmt.Printf("Headers:\n") - for k, v := range r.Header { - fmt.Printf(" %s: %s\n", k, strings.Join(v, ", ")) - } - if len(bodyBytes) > 0 { - fmt.Printf("Body: %s\n", string(bodyBytes)) - } - fmt.Printf("---\n") - - // Forward to real handler - s.realHandler.ServeHTTP(w, r) -} - -func (s *MCPTestServer) GetRequests() []MCPRequestLog { - s.requestsMutex.Lock() - defer s.requestsMutex.Unlock() - return append([]MCPRequestLog{}, s.requests...) -} - -func TestMCPClaudeCodeSessionIDCorrelation(t *testing.T) { - // Skip if Claude is not available - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("Claude CLI not available, skipping integration test") - } - - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp-claudecode") - dbPath := testutil.DatabasePath(t, "mcp-claudecode") - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - os.Setenv("MCP_AUTO_DENY_ALL", "true") // Auto-deny for predictable responses - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create MCP test server wrapper - mcpTestServer := &MCPTestServer{ - requests: []MCPRequestLog{}, - } - - // Custom HTTP server setup to wrap MCP handler - gin.SetMode(gin.ReleaseMode) - router := gin.New() - - // Add health endpoint - router.GET("/api/v1/health", func(c *gin.Context) { - c.JSON(200, gin.H{"status": "ok"}) - }) - - // Wrap MCP endpoint with test server - router.Any("/api/v1/mcp", func(c *gin.Context) { - // First time setup - get real handler from daemon - if mcpTestServer.realHandler == nil { - // Get the real MCP handler from daemon - // We'll create a simple MCP handler that auto-denies - mcpTestServer.realHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simple MCP response for testing - var req map[string]interface{} - json.NewDecoder(r.Body).Decode(&req) - - method, _ := req["method"].(string) - id := req["id"] - - response := map[string]interface{}{ - "jsonrpc": "2.0", - "id": id, - } - - switch method { - case "initialize": - response["result"] = map[string]interface{}{ - "protocolVersion": "2025-03-26", - "serverInfo": map[string]interface{}{ - "name": "test-mcp-server", - "version": "1.0.0", - }, - "capabilities": map[string]interface{}{ - "tools": map[string]interface{}{}, - }, - } - case "tools/list": - response["result"] = map[string]interface{}{ - "tools": []interface{}{ - map[string]interface{}{ - "name": "request_approval", - "description": "Request permission to execute a tool", - "inputSchema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "tool_name": map[string]string{"type": "string"}, - "input": map[string]string{"type": "object"}, - "tool_use_id": map[string]string{"type": "string"}, - }, - "required": []string{"tool_name", "input", "tool_use_id"}, - }, - }, - }, - } - case "tools/call": - // Auto-deny - response["result"] = map[string]interface{}{ - "content": []interface{}{ - map[string]interface{}{ - "type": "text", - "text": `{"behavior": "deny", "message": "Auto-denied for testing"}`, - }, - }, - } - default: - response["error"] = map[string]interface{}{ - "code": -32601, - "message": "Method not found", - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) - }) - } - - mcpTestServer.ServeHTTP(c.Writer, c.Request) - }) - - // Start HTTP server - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", httpPort)) - require.NoError(t, err) - - server := &http.Server{ - Handler: router, - } - - go func() { - server.Serve(listener) - }() - defer server.Shutdown(context.Background()) - - // Wait for HTTP server to be ready - baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort) - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond, "HTTP server did not start") - - // Open database connection - db, err := sql.Open("sqlite3", dbPath) - require.NoError(t, err) - defer db.Close() - - // Create a test session in the database - testSessionID := "test-claudecode-session" - _, err = db.Exec(` - INSERT INTO sessions ( - id, run_id, claude_session_id, query, model, working_dir, - status, created_at, last_activity_at, auto_accept_edits, - dangerously_skip_permissions, max_turns, system_prompt, - custom_instructions, cost_usd, input_tokens, output_tokens, - duration_ms, num_turns, result_content, error_message - ) VALUES ( - ?, 'run-claudecode', 'claude-test', 'test query', 'claude-3-sonnet', '/tmp', - 'running', datetime('now'), datetime('now'), 0, 1, 10, '', - '', 0.0, 0, 0, 0, 0, '', '' - ) - `, testSessionID) - require.NoError(t, err) - - // Create claudecode client - client, err := claudecode.NewClient() - require.NoError(t, err) - - // Prepare MCP configuration - // The claudecode client will write this to a temp file and pass it to claude - // We need to match the format claude expects for HTTP MCP servers - mcpConfig := &claudecode.MCPConfig{ - MCPServers: map[string]claudecode.MCPServer{ - "humanlayer": { - Command: "http", // This is just a placeholder - Env: map[string]string{ - // The actual config will be written as JSON with type/url/headers - "_config": fmt.Sprintf(`{"type":"http","url":"%s/api/v1/mcp","headers":{"X-Session-ID":"%s"}}`, baseURL, testSessionID), - }, - }, - }, - } - - // Create session config - sessionConfig := claudecode.SessionConfig{ - Query: "Say 'test complete' and exit", - Model: claudecode.ModelSonnet, - OutputFormat: claudecode.OutputStreamJSON, - MCPConfig: mcpConfig, - PermissionPromptTool: "mcp__humanlayer__request_approval", - MaxTurns: 1, - WorkingDir: tempDir, - Verbose: true, - } - - // Capture events from Claude - var allEvents []claudecode.StreamEvent - var eventsMutex sync.Mutex - - // Launch Claude session - t.Log("Launching Claude session with MCP config...") - session, err := client.Launch(sessionConfig) - require.NoError(t, err) - - // Capture events in background - eventsDone := make(chan struct{}) - go func() { - defer close(eventsDone) - for event := range session.Events { - eventsMutex.Lock() - allEvents = append(allEvents, event) - eventsMutex.Unlock() - - // Log significant events - switch event.Type { - case "system": - if event.Subtype == "init" { - t.Logf("Claude session initialized: ID=%s, Model=%s", event.SessionID, event.Model) - } - case "mcp_servers": - for _, server := range event.MCPServers { - t.Logf("MCP Server %s: %s", server.Name, server.Status) - } - case "result": - t.Logf("Session completed: ID=%s, Error=%v", event.SessionID, event.IsError) - } - } - }() - - // Wait for session to complete (with timeout) - done := make(chan struct{}) - go func() { - defer close(done) - result, err := session.Wait() - if err != nil { - t.Logf("Session error: %v", err) - } else if result != nil { - t.Logf("Session result: %s", result.Result) - } - }() - - select { - case <-done: - // Session completed - case <-time.After(30 * time.Second): - t.Log("Session timeout, interrupting...") - session.Interrupt() - <-done - } - - // Wait for events to be processed - <-eventsDone - - // Analyze captured MCP requests - requests := mcpTestServer.GetRequests() - t.Logf("\n=== MCP Request Analysis ===") - t.Logf("Total MCP requests: %d", len(requests)) - - // Check for session ID in headers - sessionIDFound := false - var sessionIDHeaders []string - - for i, req := range requests { - t.Logf("\nRequest #%d: %s", i+1, req.Method) - - // Check various possible session ID headers - possibleHeaders := []string{ - "X-Session-ID", - "X-Session-Id", - "Session-ID", - "Session-Id", - "Mcp-Session-Id", - "MCP-Session-ID", - } - - for _, header := range possibleHeaders { - if values, ok := req.Headers[header]; ok && len(values) > 0 { - sessionIDFound = true - sessionIDHeaders = append(sessionIDHeaders, fmt.Sprintf("%s: %s", header, values[0])) - t.Logf(" ✓ Found session ID header: %s = %s", header, values[0]) - } - } - - // Check if session ID is in the request body - if len(req.Body) > 0 { - var body map[string]interface{} - if err := json.Unmarshal(req.Body, &body); err == nil { - if sessionID, ok := body["session_id"].(string); ok && sessionID != "" { - t.Logf(" ✓ Found session_id in body: %s", sessionID) - } - if params, ok := body["params"].(map[string]interface{}); ok { - if sessionID, ok := params["session_id"].(string); ok && sessionID != "" { - t.Logf(" ✓ Found session_id in params: %s", sessionID) - } - } - } - } - } - - // Analyze Claude events for session information - t.Logf("\n=== Claude Event Analysis ===") - t.Logf("Total events captured: %d", len(allEvents)) - - var claudeSessionID string - for _, event := range allEvents { - if event.SessionID != "" && claudeSessionID == "" { - claudeSessionID = event.SessionID - t.Logf("Claude session ID from events: %s", claudeSessionID) - } - } - - // Final verdict - t.Logf("\n=== VERDICT ===") - if sessionIDFound { - t.Logf("✓ Session ID IS sent in MCP request headers") - t.Logf(" Headers found: %s", strings.Join(sessionIDHeaders, ", ")) - t.Logf(" Current implementation should work correctly") - } else { - t.Logf("✗ Session ID is NOT sent in MCP request headers") - t.Logf(" Claude session ID: %s", claudeSessionID) - t.Logf(" Need to implement alternative correlation mechanism") - t.Logf(" Possible solutions:") - t.Logf(" 1. Use MCP session initialization to establish mapping") - t.Logf(" 2. Pass session ID in MCP server URL path") - t.Logf(" 3. Use a unique MCP server per session") - } - - // Assert findings - if !sessionIDFound { - t.Error("Session ID is not being sent in MCP request headers - implementation needs revision") - - // Provide detailed recommendations - t.Log("\nRECOMMENDED CHANGES:") - t.Log("1. Remove reliance on Session-ID header in MCP server") - t.Log("2. Consider embedding session ID in MCP server URL:") - t.Log(" - Change URL to: /api/v1/mcp/{session_id}") - t.Log(" - Extract session ID from URL path in handler") - t.Log("3. Or use MCP session correlation:") - t.Log(" - Track MCP session ID from initialize method") - t.Log(" - Map MCP session to HumanLayer session") - } - - // Additional diagnostics - t.Logf("\n=== Additional Diagnostics ===") - if len(requests) > 0 { - t.Log("First request headers:") - for k, v := range requests[0].Headers { - t.Logf(" %s: %s", k, strings.Join(v, ", ")) - } - } -} diff --git a/hld/daemon/mcp_integration_test.go b/hld/daemon/mcp_integration_test.go deleted file mode 100644 index 92230b4..0000000 --- a/hld/daemon/mcp_integration_test.go +++ /dev/null @@ -1,290 +0,0 @@ -//go:build integration -// +build integration - -package daemon_test - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "os" - "path/filepath" - "testing" - "time" - - "github.com/humanlayer/humanlayer/hld/daemon" - "github.com/humanlayer/humanlayer/hld/internal/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMCPStubEndpoint(t *testing.T) { - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp") - _ = testutil.DatabasePath(t, "mcp") // Sets HUMANLAYER_DATABASE_PATH - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create daemon - d, err := daemon.New() - require.NoError(t, err, "Failed to create daemon") - - // Start daemon in background - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCh := make(chan error, 1) - go func() { - errCh <- d.Run(ctx) - }() - - // Wait for HTTP server to be ready - baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort) - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond, "HTTP server did not start") - - t.Run("Initialize", func(t *testing.T) { - // Test MCP initialize method - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": map[string]interface{}{ - "protocolVersion": "2025-03-26", - "capabilities": map[string]interface{}{}, - "clientInfo": map[string]interface{}{ - "name": "test", - "version": "1.0", - }, - }, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode) - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Verify response structure - assert.Equal(t, "2.0", result["jsonrpc"]) - assert.Equal(t, float64(1), result["id"]) - - // Check result field - res, ok := result["result"].(map[string]interface{}) - require.True(t, ok, "result field should be a map") - - assert.Equal(t, "2025-03-26", res["protocolVersion"]) - - serverInfo, ok := res["serverInfo"].(map[string]interface{}) - require.True(t, ok, "serverInfo should be a map") - assert.Equal(t, "humanlayer-daemon", serverInfo["name"]) - assert.Equal(t, "1.0.0", serverInfo["version"]) - }) - - t.Run("ToolsList", func(t *testing.T) { - // Test tools/list method - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/list", - "params": map[string]interface{}{}, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode) - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Check tools list - res, ok := result["result"].(map[string]interface{}) - require.True(t, ok) - - tools, ok := res["tools"].([]interface{}) - require.True(t, ok) - assert.Len(t, tools, 1) - - tool := tools[0].(map[string]interface{}) - assert.Equal(t, "request_approval", tool["name"]) - assert.Contains(t, tool["description"], "Request permission to execute a tool") - }) - - t.Run("UnknownMethod", func(t *testing.T) { - // Test unknown method - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 4, - "method": "unknown/method", - "params": map[string]interface{}{}, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode) - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Should have error response - errResp, ok := result["error"].(map[string]interface{}) - require.True(t, ok, "Should have error field") - - assert.Equal(t, float64(-32601), errResp["code"]) - assert.Contains(t, errResp["message"], "not found") - }) - - t.Run("AutoDeny", func(t *testing.T) { - // Set auto-deny mode - os.Setenv("MCP_AUTO_DENY_ALL", "true") - defer os.Unsetenv("MCP_AUTO_DENY_ALL") - - // Restart daemon with auto-deny - cancel() - - // Wait for shutdown - select { - case <-errCh: - case <-time.After(2 * time.Second): - t.Fatal("Daemon did not shut down") - } - - // Create new daemon with auto-deny - d2, err := daemon.New() - require.NoError(t, err) - - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - - errCh2 := make(chan error, 1) - go func() { - errCh2 <- d2.Run(ctx2) - }() - - // Wait for server to be ready again - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond) - - // Test tools/call with auto-deny - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 3, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "test_tool", - "input": map[string]interface{}{"test": "data"}, - "tool_use_id": "test_123", - }, - }, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode) - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Check auto-deny response - res, ok := result["result"].(map[string]interface{}) - require.True(t, ok) - - content, ok := res["content"].([]interface{}) - require.True(t, ok) - require.Len(t, content, 1) - - contentItem := content[0].(map[string]interface{}) - assert.Equal(t, "text", contentItem["type"]) - - // Parse the JSON text content - text := contentItem["text"].(string) - var approval map[string]interface{} - err = json.Unmarshal([]byte(text), &approval) - require.NoError(t, err) - - assert.Equal(t, "deny", approval["behavior"]) - assert.Contains(t, approval["message"], "Auto-denied") - }) -} - -// getFreePort gets a free TCP port for testing -func getFreePort(t *testing.T) int { - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer listener.Close() - - return listener.Addr().(*net.TCPAddr).Port -} diff --git a/hld/daemon/mcp_phase4_integration_test.go b/hld/daemon/mcp_phase4_integration_test.go deleted file mode 100644 index b2bb662..0000000 --- a/hld/daemon/mcp_phase4_integration_test.go +++ /dev/null @@ -1,404 +0,0 @@ -//go:build integration - -package daemon_test - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "fmt" - "net/http" - "os" - "path/filepath" - "testing" - "time" - - "github.com/humanlayer/humanlayer/hld/daemon" - "github.com/humanlayer/humanlayer/hld/internal/testutil" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMCPPhase4ApprovalCreation(t *testing.T) { - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp-phase4") - dbPath := testutil.DatabasePath(t, "mcp-phase4") - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create daemon - d, err := daemon.New() - require.NoError(t, err, "Failed to create daemon") - - // Start daemon in background - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCh := make(chan error, 1) - go func() { - errCh <- d.Run(ctx) - }() - - // Wait for HTTP server to be ready - baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort) - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond, "HTTP server did not start") - - // Open database connection - db, err := sql.Open("sqlite3", dbPath) - require.NoError(t, err) - defer db.Close() - - // Create a test session - sessionID := "test-session-phase4" - _, err = db.Exec(` - INSERT INTO sessions ( - id, run_id, claude_session_id, query, model, working_dir, - status, created_at, last_activity_at, auto_accept_edits, - dangerously_skip_permissions, max_turns, system_prompt, - custom_instructions, cost_usd, input_tokens, output_tokens, - duration_ms, num_turns, result_content, error_message - ) VALUES ( - ?, 'run-phase4', 'claude-phase4', 'test query', 'claude-3-sonnet', '/tmp', - 'running', datetime('now'), datetime('now'), 0, 0, 10, '', - '', 0.0, 0, 0, 0, 0, '', '' - ) - `, sessionID) - require.NoError(t, err) - - t.Run("ApprovalCreatedWithToolUseID", func(t *testing.T) { - // Send MCP approval request - toolUseID := "test_use_phase4_123" - req := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "test_tool", - "input": map[string]interface{}{"test": "data"}, - "tool_use_id": toolUseID, - }, - }, - } - - body, _ := json.Marshal(req) - httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Session-ID", sessionID) - - // Send request in background (it will block waiting for approval) - go func() { - client := &http.Client{Timeout: 2 * time.Second} - client.Do(httpReq) - }() - - // Wait for approval to be created - time.Sleep(500 * time.Millisecond) - - // Check database for approval with tool_use_id - var count int - err := db.QueryRow(` - SELECT COUNT(*) FROM approvals - WHERE tool_use_id = ? AND session_id = ? - `, toolUseID, sessionID).Scan(&count) - require.NoError(t, err) - assert.Equal(t, 1, count, "Expected exactly one approval with tool_use_id") - - // Verify approval details - var approvalID, toolName, status string - var toolInput string - err = db.QueryRow(` - SELECT id, tool_name, tool_input, status - FROM approvals - WHERE tool_use_id = ? AND session_id = ? - `, toolUseID, sessionID).Scan(&approvalID, &toolName, &toolInput, &status) - require.NoError(t, err) - - assert.Equal(t, "test_tool", toolName) - assert.Equal(t, "pending", status) - assert.Contains(t, toolInput, `"test":"data"`) - assert.NotEmpty(t, approvalID) - - t.Logf("Successfully created approval with ID=%s, tool_use_id=%s", approvalID, toolUseID) - }) - - t.Run("AutoApprovalWithDangerousSkip", func(t *testing.T) { - // Enable dangerous skip permissions - _, err := db.Exec(` - UPDATE sessions - SET dangerously_skip_permissions = 1 - WHERE id = ? - `, sessionID) - require.NoError(t, err) - - // Send MCP approval request - toolUseID := "test_use_auto_approve" - req := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "edit_tool", - "input": map[string]interface{}{"file": "test.txt"}, - "tool_use_id": toolUseID, - }, - }, - } - - body, _ := json.Marshal(req) - httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Session-ID", sessionID) - - resp, err := http.DefaultClient.Do(httpReq) - require.NoError(t, err) - defer resp.Body.Close() - - // Should get immediate response due to auto-approval - var result map[string]interface{} - json.NewDecoder(resp.Body).Decode(&result) - - // Check that the approval was created and auto-approved - var status, comment string - err = db.QueryRow(` - SELECT status, comment - FROM approvals - WHERE tool_use_id = ? AND session_id = ? - `, toolUseID, sessionID).Scan(&status, &comment) - require.NoError(t, err) - - assert.Equal(t, "approved", status) - assert.Contains(t, comment, "dangerous skip permissions") - - // Verify response contains allow behavior - if responseContent, ok := result["result"].(map[string]interface{}); ok { - if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 { - if textContent, ok := content[0].(map[string]interface{}); ok { - if text, ok := textContent["text"].(string); ok { - var responseData map[string]interface{} - json.Unmarshal([]byte(text), &responseData) - assert.Equal(t, "allow", responseData["behavior"]) - } - } - } - } - - // Disable dangerous skip for cleanup - _, err = db.Exec(` - UPDATE sessions - SET dangerously_skip_permissions = 0 - WHERE id = ? - `, sessionID) - require.NoError(t, err) - }) - - t.Run("MultipleApprovalsDifferentToolUseIDs", func(t *testing.T) { - // Create multiple approval requests with different tool_use_ids - toolUseIDs := []string{"multi_1", "multi_2", "multi_3"} - - for _, toolUseID := range toolUseIDs { - req := map[string]interface{}{ - "jsonrpc": "2.0", - "id": toolUseID, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "multi_tool", - "input": map[string]interface{}{"id": toolUseID}, - "tool_use_id": toolUseID, - }, - }, - } - - body, _ := json.Marshal(req) - httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Session-ID", sessionID) - - // Send requests in background - go func() { - client := &http.Client{Timeout: 1 * time.Second} - client.Do(httpReq) - }() - } - - // Wait for approvals to be created - time.Sleep(500 * time.Millisecond) - - // Verify all approvals were created with correct tool_use_ids - for _, toolUseID := range toolUseIDs { - var count int - err := db.QueryRow(` - SELECT COUNT(*) FROM approvals - WHERE tool_use_id = ? AND session_id = ? - `, toolUseID, sessionID).Scan(&count) - require.NoError(t, err) - assert.Equal(t, 1, count, "Expected approval for tool_use_id=%s", toolUseID) - } - - // Verify total count - var totalCount int - err := db.QueryRow(` - SELECT COUNT(*) FROM approvals - WHERE tool_use_id IN ('multi_1', 'multi_2', 'multi_3') - AND session_id = ? - `, sessionID).Scan(&totalCount) - require.NoError(t, err) - assert.Equal(t, 3, totalCount, "Expected exactly 3 approvals") - }) -} - -func TestMCPPhase4AutoDenyMode(t *testing.T) { - // Set auto-deny mode - os.Setenv("MCP_AUTO_DENY_ALL", "true") - defer os.Unsetenv("MCP_AUTO_DENY_ALL") - - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp-phase4-autodeny") - dbPath := testutil.DatabasePath(t, "mcp-phase4-autodeny") - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create daemon - d, err := daemon.New() - require.NoError(t, err, "Failed to create daemon") - - // Start daemon in background - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCh := make(chan error, 1) - go func() { - errCh <- d.Run(ctx) - }() - - // Wait for HTTP server to be ready - baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort) - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond, "HTTP server did not start") - - // Open database connection - db, err := sql.Open("sqlite3", dbPath) - require.NoError(t, err) - defer db.Close() - - // Create a test session - sessionID := "test-session-autodeny" - _, err = db.Exec(` - INSERT INTO sessions ( - id, run_id, claude_session_id, query, model, working_dir, - status, created_at, last_activity_at, auto_accept_edits, - dangerously_skip_permissions, max_turns, system_prompt, - custom_instructions, cost_usd, input_tokens, output_tokens, - duration_ms, num_turns, result_content, error_message - ) VALUES ( - ?, 'run-autodeny', 'claude-autodeny', 'test query', 'claude-3-sonnet', '/tmp', - 'running', datetime('now'), datetime('now'), 0, 0, 10, '', - '', 0.0, 0, 0, 0, 0, '', '' - ) - `, sessionID) - require.NoError(t, err) - - t.Run("AutoDenyDoesNotCreateApproval", func(t *testing.T) { - // Send MCP approval request - toolUseID := "test_autodeny" - req := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "test_tool", - "input": map[string]interface{}{"test": "data"}, - "tool_use_id": toolUseID, - }, - }, - } - - body, _ := json.Marshal(req) - httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Session-ID", sessionID) - - resp, err := http.DefaultClient.Do(httpReq) - require.NoError(t, err) - defer resp.Body.Close() - - // Should get immediate deny response - var result map[string]interface{} - json.NewDecoder(resp.Body).Decode(&result) - - // Verify deny response - if responseContent, ok := result["result"].(map[string]interface{}); ok { - if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 { - if textContent, ok := content[0].(map[string]interface{}); ok { - if text, ok := textContent["text"].(string); ok { - var responseData map[string]interface{} - json.Unmarshal([]byte(text), &responseData) - assert.Equal(t, "deny", responseData["behavior"]) - assert.Contains(t, responseData["message"], "Auto-denied") - } - } - } - } - - // Verify no approval was created in database - var count int - err = db.QueryRow(` - SELECT COUNT(*) FROM approvals - WHERE tool_use_id = ? - `, toolUseID).Scan(&count) - require.NoError(t, err) - assert.Equal(t, 0, count, "No approval should be created in auto-deny mode") - }) -} diff --git a/hld/daemon/mcp_server_integration_test.go b/hld/daemon/mcp_server_integration_test.go deleted file mode 100644 index 7511b1c..0000000 --- a/hld/daemon/mcp_server_integration_test.go +++ /dev/null @@ -1,260 +0,0 @@ -//go:build integration -// +build integration - -package daemon_test - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "path/filepath" - "testing" - "time" - - "github.com/humanlayer/humanlayer/hld/daemon" - "github.com/humanlayer/humanlayer/hld/internal/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMCPServerFullImplementation(t *testing.T) { - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp-full") - _ = testutil.DatabasePath(t, "mcp-full") - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - os.Setenv("MCP_AUTO_DENY_ALL", "true") // Enable auto-deny for predictable testing - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create daemon - d, err := daemon.New() - require.NoError(t, err, "Failed to create daemon") - - // Start daemon in background - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCh := make(chan error, 1) - go func() { - errCh <- d.Run(ctx) - }() - - // Wait for HTTP server to be ready - baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort) - require.Eventually(t, func() bool { - resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL)) - if err == nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 5*time.Second, 100*time.Millisecond, "HTTP server did not start") - - t.Run("ToolsListSchemaValidation", func(t *testing.T) { - // Test that tools/list returns proper schema structure - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/list", - "params": map[string]interface{}{}, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Validate the tool schema structure - res := result["result"].(map[string]interface{}) - tools := res["tools"].([]interface{}) - require.Len(t, tools, 1) - - tool := tools[0].(map[string]interface{}) - assert.Equal(t, "request_approval", tool["name"]) - assert.Equal(t, "Request permission to execute a tool", tool["description"]) - - // Check input schema structure - inputSchema := tool["inputSchema"].(map[string]interface{}) - assert.Equal(t, "object", inputSchema["type"]) - - properties := inputSchema["properties"].(map[string]interface{}) - assert.Contains(t, properties, "tool_name") - assert.Contains(t, properties, "input") - assert.Contains(t, properties, "tool_use_id") - - // Verify required fields - required := inputSchema["required"].([]interface{}) - assert.Len(t, required, 3) - assert.Contains(t, required, "tool_name") - assert.Contains(t, required, "input") - assert.Contains(t, required, "tool_use_id") - - // Check annotations (mark3labs specific) - if annotations, ok := tool["annotations"].(map[string]interface{}); ok { - assert.NotNil(t, annotations["destructiveHint"]) - assert.NotNil(t, annotations["openWorldHint"]) - } - }) - - t.Run("AutoDenyResponseStructure", func(t *testing.T) { - // Test that auto-deny returns proper JSON structure - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "test_tool", - "input": map[string]interface{}{"command": "ls -la"}, - "tool_use_id": "test_use_123", - }, - }, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Validate response structure - res := result["result"].(map[string]interface{}) - content := res["content"].([]interface{}) - require.Len(t, content, 1) - - contentItem := content[0].(map[string]interface{}) - assert.Equal(t, "text", contentItem["type"]) - - // Parse and validate the JSON in the text field - text := contentItem["text"].(string) - var approvalResponse map[string]interface{} - err = json.Unmarshal([]byte(text), &approvalResponse) - require.NoError(t, err) - - assert.Equal(t, "deny", approvalResponse["behavior"]) - assert.Equal(t, "Auto-denied for testing", approvalResponse["message"]) - }) - - t.Run("SessionIDHeaderExtraction", func(t *testing.T) { - // Test that X-Session-ID header is properly handled - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 3, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - "tool_name": "test_with_session", - "input": map[string]interface{}{"test": "data"}, - "tool_use_id": "session_test_456", - }, - }, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - req, err := http.NewRequest("POST", - fmt.Sprintf("%s/api/v1/mcp", baseURL), - bytes.NewBuffer(body)) - require.NoError(t, err) - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Session-ID", "test-session-789") - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should still get auto-deny response (session ID doesn't affect auto-deny) - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Verify we got a valid response (session ID was accepted) - assert.Contains(t, result, "result") - assert.NotContains(t, result, "error") - }) - - t.Run("MissingRequiredFields", func(t *testing.T) { - // Test that missing required fields return appropriate errors - reqBody := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 4, - "method": "tools/call", - "params": map[string]interface{}{ - "name": "request_approval", - "arguments": map[string]interface{}{ - // Missing tool_use_id - "tool_name": "incomplete_tool", - "input": map[string]interface{}{"test": "data"}, - }, - }, - } - - body, err := json.Marshal(reqBody) - require.NoError(t, err) - - resp, err := http.Post( - fmt.Sprintf("%s/api/v1/mcp", baseURL), - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var result map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&result) - require.NoError(t, err) - - // Should still work in auto-deny mode (gets empty string for missing field) - // but in real mode would be problematic - if errField, hasError := result["error"]; hasError { - // If there's an error, it should be about the missing field - errMap := errField.(map[string]interface{}) - assert.Contains(t, errMap["message"], "required") - } else { - // In auto-deny mode, it might still process with empty tool_use_id - res := result["result"].(map[string]interface{}) - assert.NotNil(t, res) - } - }) -} diff --git a/hld/daemon/mcp_session_header_test.go b/hld/daemon/mcp_session_header_test.go deleted file mode 100644 index 32abff1..0000000 --- a/hld/daemon/mcp_session_header_test.go +++ /dev/null @@ -1,282 +0,0 @@ -//go:build integration - -package daemon_test - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "os/exec" - "strings" - "sync" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// HeaderCaptureServer captures all HTTP request headers for analysis -type HeaderCaptureServer struct { - mu sync.Mutex - requests []RequestCapture -} - -type RequestCapture struct { - Method string - Path string - Headers map[string][]string - Body []byte - Timestamp time.Time -} - -func (s *HeaderCaptureServer) Handler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Capture request - body, _ := io.ReadAll(r.Body) - r.Body = io.NopCloser(bytes.NewReader(body)) - - capture := RequestCapture{ - Method: r.Method, - Path: r.URL.Path, - Headers: r.Header.Clone(), - Body: body, - Timestamp: time.Now(), - } - - s.mu.Lock() - s.requests = append(s.requests, capture) - s.mu.Unlock() - - // Log headers - fmt.Printf("\n[%s %s]\n", r.Method, r.URL.Path) - fmt.Println("Headers:") - for k, v := range r.Header { - fmt.Printf(" %s: %s\n", k, strings.Join(v, ", ")) - } - - // Parse JSON-RPC request - var req map[string]interface{} - json.Unmarshal(body, &req) - - method, _ := req["method"].(string) - id := req["id"] - - // Create response - response := map[string]interface{}{ - "jsonrpc": "2.0", - "id": id, - } - - // Handle MCP methods - switch method { - case "initialize": - response["result"] = map[string]interface{}{ - "protocolVersion": "2025-03-26", - "serverInfo": map[string]interface{}{ - "name": "test-server", - "version": "1.0.0", - }, - "capabilities": map[string]interface{}{ - "tools": map[string]interface{}{}, - }, - } - case "tools/list": - response["result"] = map[string]interface{}{ - "tools": []interface{}{ - map[string]interface{}{ - "name": "test_tool", - "description": "A test tool", - "inputSchema": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{}, - }, - }, - }, - } - case "tools/call": - response["result"] = map[string]interface{}{ - "content": []interface{}{ - map[string]interface{}{ - "type": "text", - "text": "Test response", - }, - }, - } - default: - response["error"] = map[string]interface{}{ - "code": -32601, - "message": "Method not found", - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) - } -} - -func (s *HeaderCaptureServer) GetRequests() []RequestCapture { - s.mu.Lock() - defer s.mu.Unlock() - return append([]RequestCapture{}, s.requests...) -} - -func TestMCPHeaderTransmission(t *testing.T) { - // Skip if Claude CLI is not available - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("Claude CLI not available") - } - - // Create header capture server - captureServer := &HeaderCaptureServer{} - - // Start HTTP server - gin.SetMode(gin.ReleaseMode) - router := gin.New() - router.Any("/mcp", gin.WrapF(captureServer.Handler())) - - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - port := listener.Addr().(*net.TCPAddr).Port - baseURL := fmt.Sprintf("http://127.0.0.1:%d", port) - - server := &http.Server{Handler: router} - go server.Serve(listener) - defer server.Shutdown(context.Background()) - - // Create MCP config file with custom headers - tempDir := t.TempDir() - mcpConfigPath := fmt.Sprintf("%s/mcp-config.json", tempDir) - - testSessionID := "test-session-123" - mcpConfig := map[string]interface{}{ - "mcpServers": map[string]interface{}{ - "test": map[string]interface{}{ - "type": "http", - "url": fmt.Sprintf("%s/mcp", baseURL), - "headers": map[string]string{ - "X-Session-ID": testSessionID, - "X-Custom-Header": "custom-value", - "Authorization": "Bearer test-token", - }, - }, - }, - } - - configBytes, _ := json.MarshalIndent(mcpConfig, "", " ") - err = os.WriteFile(mcpConfigPath, configBytes, 0644) - require.NoError(t, err) - - t.Logf("MCP Config:\n%s", string(configBytes)) - - // Launch Claude with MCP config - cmd := exec.Command("claude", - "--print", "test", - "--mcp-config", mcpConfigPath, - "--max-turns", "1", - "--model", "sonnet", - "--output-format", "json", - ) - cmd.Dir = tempDir - - output, err := cmd.CombinedOutput() - t.Logf("Claude output:\n%s", string(output)) - - // Allow some time for any async operations - time.Sleep(500 * time.Millisecond) - - // Analyze captured requests - requests := captureServer.GetRequests() - t.Logf("\n=== Captured %d MCP Requests ===", len(requests)) - - sessionIDFound := false - customHeaderFound := false - authHeaderFound := false - - for i, req := range requests { - t.Logf("\nRequest #%d: %s %s", i+1, req.Method, req.Path) - - // Check for our custom headers - if sessionIDs, ok := req.Headers["X-Session-Id"]; ok && len(sessionIDs) > 0 { - sessionIDFound = true - sessionID := sessionIDs[0] - t.Logf(" ✓ X-Session-ID: %s", sessionID) - if sessionID != testSessionID { - t.Errorf(" ✗ Session ID mismatch: got %s, want %s", sessionID, testSessionID) - } - } - - if customs, ok := req.Headers["X-Custom-Header"]; ok && len(customs) > 0 { - customHeaderFound = true - t.Logf(" ✓ X-Custom-Header: %s", customs[0]) - } - - if auths, ok := req.Headers["Authorization"]; ok && len(auths) > 0 { - authHeaderFound = true - t.Logf(" ✓ Authorization: %s", auths[0]) - } - - // Log all headers for debugging - t.Log(" All headers:") - for k, v := range req.Headers { - t.Logf(" %s: %s", k, strings.Join(v, ", ")) - } - } - - // Verdict - t.Log("\n=== VERDICT ===") - if sessionIDFound && customHeaderFound && authHeaderFound { - t.Log("✓ All custom headers are transmitted correctly") - t.Log("✓ Session ID correlation via headers WILL work") - } else { - t.Log("✗ Custom headers are NOT being transmitted") - if !sessionIDFound { - t.Error("X-Session-ID header not found in MCP requests") - } - if !customHeaderFound { - t.Error("X-Custom-Header not found in MCP requests") - } - if !authHeaderFound { - t.Error("Authorization header not found in MCP requests") - } - - t.Log("\n=== RECOMMENDATIONS ===") - t.Log("1. Embed session ID in the MCP server URL path") - t.Log("2. Use unique MCP server instances per session") - t.Log("3. Implement session correlation via MCP protocol messages") - } -} - -func TestMCPSessionCorrelationAlternatives(t *testing.T) { - t.Log("\n=== Alternative Session Correlation Methods ===") - - t.Log("\n1. URL Path Embedding:") - t.Log(" - Change MCP endpoint to: /api/v1/mcp/:session_id") - t.Log(" - Extract session ID from URL path in handler") - t.Log(" - Pro: Simple, reliable, no header dependency") - t.Log(" - Con: Requires URL generation per session") - - t.Log("\n2. MCP Protocol Session:") - t.Log(" - Use MCP's initialize response to establish session") - t.Log(" - Store mapping: mcp_session_id -> humanlayer_session_id") - t.Log(" - Pro: Protocol-native solution") - t.Log(" - Con: Requires stateful MCP server") - - t.Log("\n3. Token-based Correlation:") - t.Log(" - Generate unique token per session") - t.Log(" - Pass token in MCP server name or URL") - t.Log(" - Pro: Secure, unique per session") - t.Log(" - Con: Token management complexity") - - t.Log("\n4. Process-based Correlation:") - t.Log(" - Track Claude process ID") - t.Log(" - Map process to session at launch") - t.Log(" - Pro: OS-level tracking") - t.Log(" - Con: Complex, platform-specific") -} diff --git a/hld/daemon/mcp_tool_use_id_integration_test.go b/hld/daemon/mcp_tool_use_id_integration_test.go deleted file mode 100644 index e0447c0..0000000 --- a/hld/daemon/mcp_tool_use_id_integration_test.go +++ /dev/null @@ -1,358 +0,0 @@ -//go:build integration - -package daemon_test - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "fmt" - "net/http" - "os" - "path/filepath" - "testing" - "time" - - "github.com/humanlayer/humanlayer/hld/daemon" - "github.com/humanlayer/humanlayer/hld/internal/testutil" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/require" -) - -// TestMCPToolUseIDCorrelation verifies that when an approval is triggered -// by a running Claude Code instance, the tool_use_id is properly set in the database -func TestMCPToolUseIDCorrelation(t *testing.T) { - // Setup isolated environment - socketPath := testutil.SocketPath(t, "mcp-tool-use-id") - dbPath := testutil.DatabasePath(t, "mcp-tool-use-id") - - // Get a free port for HTTP server - httpPort := getFreePort(t) - - // Override environment - os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath) - os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort)) - os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1") - os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API - - // Create isolated config - tempDir := t.TempDir() - os.Setenv("XDG_CONFIG_HOME", tempDir) - configDir := filepath.Join(tempDir, "humanlayer") - require.NoError(t, os.MkdirAll(configDir, 0755)) - configFile := filepath.Join(configDir, "humanlayer.json") - require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644)) - - // Create daemon - d, err := daemon.New() - require.NoError(t, err, "Failed to create daemon") - - // Start daemon in background - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCh := make(chan error, 1) - go func() { - errCh <- d.Run(ctx) - }() - - // Wait for daemon to be ready - require.Eventually(t, func() bool { - // Check if the HTTP health endpoint is responding - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/health", httpPort)) - if err == nil && resp != nil { - resp.Body.Close() - return resp.StatusCode == 200 - } - return false - }, 10*time.Second, 100*time.Millisecond, "Daemon did not start") - - // Open database connection - db, err := sql.Open("sqlite3", dbPath) - require.NoError(t, err) - defer db.Close() - - // We'll use daemon's REST API to launch sessions properly - - t.Run("SingleApprovalWithToolUseID", func(t *testing.T) { - // Clear any existing approvals - _, err = db.Exec("DELETE FROM approvals") - require.NoError(t, err) - - // Create temp directory for session - testWorkDir := t.TempDir() - - // Prepare session creation request for REST API - createReq := map[string]interface{}{ - "query": "Write 'Hello World' to a file called test.txt and then exit", - "model": "sonnet", - "permission_prompt_tool": "mcp__codelayer__request_approval", - "max_turns": 3, - "working_dir": testWorkDir, - "mcp_config": map[string]interface{}{ - "mcp_servers": map[string]interface{}{ - "codelayer": map[string]interface{}{ - "type": "http", - "url": fmt.Sprintf("http://127.0.0.1:%d/api/v1/mcp", httpPort), - }, - }, - }, - } - - // Send REST API request to create session - reqBody, _ := json.Marshal(createReq) - httpReq, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d/api/v1/sessions", httpPort), bytes.NewBuffer(reqBody)) - require.NoError(t, err) - httpReq.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(httpReq) - require.NoError(t, err) - defer resp.Body.Close() - - // Check response status - require.Equal(t, http.StatusCreated, resp.StatusCode, "Expected 201 Created") - - // Parse response - var createResp map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&createResp) - require.NoError(t, err) - - // Get session ID from response - data := createResp["data"].(map[string]interface{}) - sessionID := data["session_id"].(string) - runID := data["run_id"].(string) - t.Logf("Launched session: %s with run_id: %s", sessionID, runID) - - // Let Claude run for a bit to trigger approvals - t.Log("Waiting for Claude to trigger approvals...") - time.Sleep(5 * time.Second) - - // Now check the database for approvals - rows, err := db.Query(` - SELECT id, session_id, tool_name, tool_use_id, status, comment - FROM approvals - ORDER BY created_at DESC - `) - require.NoError(t, err) - defer rows.Close() - - var approvals []struct { - ID string - SessionID string - ToolName string - ToolUseID sql.NullString - Status string - Comment sql.NullString - } - - for rows.Next() { - var a struct { - ID string - SessionID string - ToolName string - ToolUseID sql.NullString - Status string - Comment sql.NullString - } - err := rows.Scan(&a.ID, &a.SessionID, &a.ToolName, &a.ToolUseID, &a.Status, &a.Comment) - require.NoError(t, err) - approvals = append(approvals, a) - } - - // Log what we found - t.Logf("Found %d approvals in database:", len(approvals)) - for i, a := range approvals { - t.Logf(" Approval %d:", i+1) - t.Logf(" ID: %s", a.ID) - t.Logf(" Session ID: %s", a.SessionID) - t.Logf(" Tool Name: %s", a.ToolName) - t.Logf(" Tool Use ID: %v (Valid: %v)", a.ToolUseID.String, a.ToolUseID.Valid) - t.Logf(" Status: %s", a.Status) - if a.Comment.Valid { - t.Logf(" Comment: %s", a.Comment.String) - } - } - - // Also check conversation events for tool uses - var toolUseCount int - rows2, err := db.Query(` - SELECT tool_id, tool_name - FROM conversation_events - WHERE session_id = ? AND tool_id IS NOT NULL - ORDER BY created_at DESC - `, sessionID) - if err == nil { - defer rows2.Close() - for rows2.Next() { - var toolID, toolName string - if err := rows2.Scan(&toolID, &toolName); err == nil { - toolUseCount++ - t.Logf(" Tool use in events: %s (ID: %s)", toolName, toolID) - } - } - } - t.Logf("Found %d tool uses in conversation_events", toolUseCount) - - // Verify that we have at least one approval - if len(approvals) > 0 { - // Check that tool_use_id is set - for _, a := range approvals { - if !a.ToolUseID.Valid || a.ToolUseID.String == "" { - t.Errorf("Approval %s has no tool_use_id set!", a.ID) - } else { - t.Logf("✓ Approval %s has tool_use_id: %s", a.ID, a.ToolUseID.String) - } - } - } else { - t.Log("No approvals were created - this might indicate the test didn't trigger any tools") - t.Log("This can happen if Claude doesn't attempt to write the file") - } - }) - - t.Run("ParallelApprovalsWithDistinctToolUseIDs", func(t *testing.T) { - // Clear any existing approvals - _, err = db.Exec("DELETE FROM approvals") - require.NoError(t, err) - - // Create temp directory for session - testWorkDir := t.TempDir() - - // Prepare session creation request for REST API - createReq := map[string]interface{}{ - "query": "Create 3 files in parallel: file1.txt with 'One', file2.txt with 'Two', file3.txt with 'Three'. Use parallel tool calls if possible.", - "model": "sonnet", - "permission_prompt_tool": "mcp__codelayer__request_approval", - "max_turns": 3, - "working_dir": testWorkDir, - "mcp_config": map[string]interface{}{ - "mcp_servers": map[string]interface{}{ - "codelayer": map[string]interface{}{ - "type": "http", - "url": fmt.Sprintf("http://127.0.0.1:%d/api/v1/mcp", httpPort), - }, - }, - }, - } - - // Send REST API request to create session - reqBody, _ := json.Marshal(createReq) - httpReq, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d/api/v1/sessions", httpPort), bytes.NewBuffer(reqBody)) - require.NoError(t, err) - httpReq.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(httpReq) - require.NoError(t, err) - defer resp.Body.Close() - - // Check response status - require.Equal(t, http.StatusCreated, resp.StatusCode, "Expected 201 Created") - - // Parse response - var createResp map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&createResp) - require.NoError(t, err) - - // Get session ID from response - data := createResp["data"].(map[string]interface{}) - sessionID := data["session_id"].(string) - t.Logf("Launched parallel session: %s", sessionID) - - // Let Claude run for a bit - t.Log("Waiting for parallel operations...") - time.Sleep(7 * time.Second) - - // Check database for approvals - rows, err := db.Query(` - SELECT id, tool_use_id, tool_name - FROM approvals - ORDER BY created_at DESC - `) - require.NoError(t, err) - defer rows.Close() - - var approvals []struct { - ID string - ToolUseID sql.NullString - ToolName string - } - - for rows.Next() { - var a struct { - ID string - ToolUseID sql.NullString - ToolName string - } - err := rows.Scan(&a.ID, &a.ToolUseID, &a.ToolName) - require.NoError(t, err) - approvals = append(approvals, a) - } - - t.Logf("Found %d approvals for parallel operations", len(approvals)) - - // Verify each approval has a unique tool_use_id - toolUseIDMap := make(map[string]bool) - for _, a := range approvals { - if !a.ToolUseID.Valid || a.ToolUseID.String == "" { - t.Errorf("Approval %s has no tool_use_id!", a.ID) - } else { - if toolUseIDMap[a.ToolUseID.String] { - t.Errorf("Duplicate tool_use_id found: %s", a.ToolUseID.String) - } - toolUseIDMap[a.ToolUseID.String] = true - t.Logf("✓ Approval %s has unique tool_use_id: %s", a.ID, a.ToolUseID.String) - } - } - - // Cross-reference with conversation events - var toolUseEvents []struct { - ID string - Name string - } - rows2, err := db.Query(` - SELECT tool_id, tool_name - FROM conversation_events - WHERE session_id = ? AND tool_id IS NOT NULL - `, sessionID) - if err == nil { - defer rows2.Close() - for rows2.Next() { - var toolID, toolName string - if err := rows2.Scan(&toolID, &toolName); err == nil { - toolUseEvents = append(toolUseEvents, struct { - ID string - Name string - }{ID: toolID, Name: toolName}) - } - } - } - - t.Logf("Cross-referencing %d tool_use events with approvals", len(toolUseEvents)) - for _, toolUse := range toolUseEvents { - found := false - for _, a := range approvals { - if a.ToolUseID.Valid && a.ToolUseID.String == toolUse.ID { - found = true - t.Logf("✓ Tool use %s matched with approval %s", toolUse.ID, a.ID) - break - } - } - if !found && toolUse.ID != "" { - t.Logf("⚠ Tool use %s (%s) has no matching approval", toolUse.ID, toolUse.Name) - } - } - }) - - // Cleanup: shutdown daemon - cancel() - select { - case err := <-errCh: - if err != nil && err != context.Canceled { - t.Errorf("Daemon exited with error: %v", err) - } - case <-time.After(5 * time.Second): - t.Error("Daemon did not shut down in time") - } -} diff --git a/hld/go.mod b/hld/go.mod index 7d75f14..9005ae9 100644 --- a/hld/go.mod +++ b/hld/go.mod @@ -11,13 +11,10 @@ require ( github.com/getkin/kin-openapi v0.132.0 github.com/gin-contrib/cors v1.7.6 github.com/gin-gonic/gin v1.10.1 - github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 github.com/humanlayer/humanlayer/claudecode-go v0.0.0-00010101000000-000000000000 - github.com/mark3labs/mcp-go v0.37.0 github.com/mattn/go-sqlite3 v1.14.28 github.com/oapi-codegen/runtime v1.1.2 - github.com/r3labs/sse/v2 v2.10.0 github.com/spf13/viper v1.20.1 github.com/stretchr/testify v1.10.0 go.uber.org/mock v0.5.2 @@ -25,8 +22,6 @@ require ( require ( github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect - github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/sonic v1.13.3 // indirect github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/cloudwego/base64x v0.1.5 // indirect @@ -41,7 +36,6 @@ require ( github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/goccy/go-json v0.10.5 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect @@ -56,6 +50,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/r3labs/sse/v2 v2.10.0 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect @@ -64,8 +59,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect - github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.18.0 // indirect diff --git a/hld/go.sum b/hld/go.sum index 3764831..3b87480 100644 --- a/hld/go.sum +++ b/hld/go.sum @@ -1,11 +1,7 @@ github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= -github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= -github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= -github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= -github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0= github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -49,15 +45,11 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= -github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -75,8 +67,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ= -github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= @@ -133,11 +123,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= -github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= -github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= -github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= @@ -147,37 +132,18 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc= golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= diff --git a/hld/mcp/server.go b/hld/mcp/server.go deleted file mode 100644 index 6f06c5f..0000000 --- a/hld/mcp/server.go +++ /dev/null @@ -1,260 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "net/http" - "os" - "sync" - - "github.com/humanlayer/humanlayer/hld/approval" - "github.com/humanlayer/humanlayer/hld/bus" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -// contextKey is the type for context keys -type contextKey string - -const ( - // sessionIDKey is the context key for session ID - sessionIDKey contextKey = "session_id" -) - -// ApprovalDecision represents the outcome of an approval request -type ApprovalDecision struct { - Approved bool - Comment string -} - -// MCPServer wraps the mark3labs MCP server -type MCPServer struct { - mcpServer *server.MCPServer - httpServer *server.StreamableHTTPServer - approvalManager approval.Manager - eventBus bus.EventBus - autoDenyAll bool - pendingApprovals sync.Map // map[string]chan ApprovalDecision -} - -// NewMCPServer creates the full MCP server implementation -func NewMCPServer(approvalManager approval.Manager, eventBus bus.EventBus) *MCPServer { - autoDeny := os.Getenv("MCP_AUTO_DENY_ALL") == "true" - - s := &MCPServer{ - approvalManager: approvalManager, - eventBus: eventBus, - autoDenyAll: autoDeny, - } - - // Create MCP server - s.mcpServer = server.NewMCPServer( - "humanlayer-daemon", - "1.0.0", - server.WithToolCapabilities(true), - ) - - // Add request_approval tool - s.mcpServer.AddTool( - mcp.NewTool("request_approval", - mcp.WithDescription("Request permission to execute a tool"), - mcp.WithString("tool_name", - mcp.Description("The name of the tool requesting permission"), - mcp.Required(), - ), - mcp.WithObject("input", - mcp.Description("The input to the tool"), - mcp.Required(), - ), - mcp.WithString("tool_use_id", - mcp.Description("Unique identifier for this tool use"), - mcp.Required(), - ), - ), - s.handleRequestApproval, - ) - - // Create HTTP server (stateless for now) - s.httpServer = server.NewStreamableHTTPServer( - s.mcpServer, - server.WithStateLess(true), - ) - - // Don't start goroutine here - wait for Start() to be called - return s -} - -// Start initializes the MCP server's background processes -func (s *MCPServer) Start(ctx context.Context) { - if s.eventBus != nil { - go s.listenForApprovalDecisions(ctx) - } -} - -func (s *MCPServer) handleRequestApproval(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - toolName := request.GetString("tool_name", "") - input := request.GetArguments()["input"] - toolUseID := request.GetString("tool_use_id", "") - - slog.Info("MCP approval requested", - "tool_name", toolName, - "tool_use_id", toolUseID, - "auto_deny", s.autoDenyAll) - - // Auto-deny takes precedence - if s.autoDenyAll { - slog.Info("Auto-denying approval", "tool_use_id", toolUseID) - - responseData := map[string]interface{}{ - "behavior": "deny", - "message": "Auto-denied for testing", - } - responseJSON, _ := json.Marshal(responseData) - - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: string(responseJSON), - }, - }, - }, nil - } - - // Get session_id from context - sessionID, _ := ctx.Value(sessionIDKey).(string) - if sessionID == "" { - return nil, fmt.Errorf("missing session_id in context") - } - - // Marshal input to JSON - inputJSON, err := json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("failed to marshal input: %w", err) - } - - // Create approval with tool_use_id - approval, err := s.approvalManager.CreateApprovalWithToolUseID(ctx, sessionID, toolName, inputJSON, toolUseID) - if err != nil { - slog.Error("Failed to create approval", "error", err) - return nil, fmt.Errorf("failed to create approval: %w", err) - } - - slog.Info("Created approval", "approval_id", approval.ID, "status", approval.Status) - - // Check if the approval was auto-approved - if approval.Status == "approved" { - // Return allow behavior for auto-approved - responseData := map[string]interface{}{ - "behavior": "allow", - "updatedInput": input, - } - responseJSON, _ := json.Marshal(responseData) - - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: string(responseJSON), - }, - }, - }, nil - } - - // Register for event-driven approval resolution - decisionChan := make(chan ApprovalDecision, 1) - s.pendingApprovals.Store(toolUseID, decisionChan) - defer s.pendingApprovals.Delete(toolUseID) - - // Wait for approval decision - select { - case decision := <-decisionChan: - responseData := map[string]interface{}{ - "behavior": "deny", - "message": decision.Comment, - } - if decision.Approved { - responseData = map[string]interface{}{ - "behavior": "allow", - "updatedInput": input, - } - } - responseJSON, _ := json.Marshal(responseData) - - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: string(responseJSON), - }, - }, - }, nil - - // For the moment, we don't timeout approvals, but in the future - // may choose to add a timeout or determine otherwise for resumed sessions - // case <-time.After(5 * time.Minute): - // return nil, fmt.Errorf("approval timeout") - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -func (s *MCPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Extract session_id from header and add to context - sessionID := r.Header.Get("X-Session-ID") - if sessionID == "" { - // Try to extract from MCP session if available - mcpSessionID := r.Header.Get("Mcp-Session-Id") - if mcpSessionID != "" { - sessionID = mcpSessionID - } - } - - // Add session_id to context for future use - ctx := context.WithValue(r.Context(), sessionIDKey, sessionID) - r = r.WithContext(ctx) - - s.httpServer.ServeHTTP(w, r) -} - -// listenForApprovalDecisions listens for approval resolution events and notifies waiting handlers -func (s *MCPServer) listenForApprovalDecisions(ctx context.Context) { - sub := s.eventBus.Subscribe(ctx, bus.EventFilter{ - Types: []bus.EventType{bus.EventApprovalResolved}, - }) - - for { - select { - case <-ctx.Done(): - slog.Info("MCP approval listener shutting down") - return - case event, ok := <-sub.Channel: - if !ok { - slog.Info("MCP approval listener channel closed") - return - } - toolUseID, _ := event.Data["tool_use_id"].(string) - approved, _ := event.Data["approved"].(bool) - comment, _ := event.Data["response_text"].(string) - - if toolUseID == "" { - continue - } - - // Find pending approval channel - if ch, ok := s.pendingApprovals.Load(toolUseID); ok { - select { - case ch.(chan ApprovalDecision) <- ApprovalDecision{ - Approved: approved, - Comment: comment, - }: - slog.Info("Sent approval decision", "tool_use_id", toolUseID, "approved", approved) - default: - slog.Warn("Channel full or closed", "tool_use_id", toolUseID) - } - } - } - } -} diff --git a/hld/sdk/typescript/src/generated/models/MCPServer.ts b/hld/sdk/typescript/src/generated/models/MCPServer.ts index 4423e43..0287911 100644 --- a/hld/sdk/typescript/src/generated/models/MCPServer.ts +++ b/hld/sdk/typescript/src/generated/models/MCPServer.ts @@ -20,47 +20,30 @@ import { mapValues } from '../runtime'; */ export interface MCPServer { /** - * Server type (http for HTTP servers, omit for stdio) + * Command to execute * @type {string} * @memberof MCPServer */ - type?: string; + command: string; /** - * Command to execute (for stdio servers) - * @type {string} - * @memberof MCPServer - */ - command?: string; - /** - * Command arguments (for stdio servers) + * Command arguments * @type {Array} * @memberof MCPServer */ args?: Array; /** - * Environment variables (for stdio servers) + * Environment variables * @type {{ [key: string]: string; }} * @memberof MCPServer */ env?: { [key: string]: string; }; - /** - * HTTP endpoint URL (for HTTP servers) - * @type {string} - * @memberof MCPServer - */ - url?: string; - /** - * HTTP headers to include (for HTTP servers) - * @type {{ [key: string]: string; }} - * @memberof MCPServer - */ - headers?: { [key: string]: string; }; } /** * Check if a given object implements the MCPServer interface. */ export function instanceOfMCPServer(value: object): value is MCPServer { + if (!('command' in value) || value['command'] === undefined) return false; return true; } @@ -74,12 +57,9 @@ export function MCPServerFromJSONTyped(json: any, ignoreDiscriminator: boolean): } return { - 'type': json['type'] == null ? undefined : json['type'], - 'command': json['command'] == null ? undefined : json['command'], + 'command': json['command'], 'args': json['args'] == null ? undefined : json['args'], 'env': json['env'] == null ? undefined : json['env'], - 'url': json['url'] == null ? undefined : json['url'], - 'headers': json['headers'] == null ? undefined : json['headers'], }; } @@ -94,11 +74,8 @@ export function MCPServerToJSONTyped(value?: MCPServer | null, ignoreDiscriminat return { - 'type': value['type'], 'command': value['command'], 'args': value['args'], 'env': value['env'], - 'url': value['url'], - 'headers': value['headers'], }; } diff --git a/hld/session/continue_inheritance_test.go b/hld/session/continue_inheritance_test.go index 7d862bd..43b88e5 100644 --- a/hld/session/continue_inheritance_test.go +++ b/hld/session/continue_inheritance_test.go @@ -779,123 +779,4 @@ func TestContinueSessionInheritance(t *testing.T) { t.Errorf("Child didn't inherit grandparent title: got %q, want %q", childSession.Title, grandparentTitle) } }) - - t.Run("HTTPMCPServerUpdatesXSessionIDHeader", func(t *testing.T) { - // This test verifies that when continuing a session with HTTP MCP servers, - // the X-Session-ID header is updated to the child session ID, not inherited - // from the parent session ID. - - // Create parent session - parentSessionID := "parent-http-mcp" - parentSession := &store.Session{ - ID: parentSessionID, - RunID: "run-http-mcp", - ClaudeSessionID: "claude-http-mcp", - Status: store.SessionStatusCompleted, - Query: "http mcp query", - Model: "claude-3-opus-20240229", - WorkingDir: "/tmp/test", - CreatedAt: time.Now(), - LastActivityAt: time.Now(), - CompletedAt: &time.Time{}, - } - - if err := sqliteStore.CreateSession(ctx, parentSession); err != nil { - t.Fatalf("Failed to create parent session: %v", err) - } - - // Store HTTP MCP server with X-Session-ID header for parent - // For HTTP servers: Command="http", ArgsJSON=["URL"], EnvJSON=headers - parentMCPServers := []store.MCPServer{ - { - SessionID: parentSessionID, - Name: "http-test-server", - Command: "http", // Indicates HTTP type - ArgsJSON: `["http://localhost:8080/mcp"]`, // URL as single-element array - EnvJSON: `{"X-Session-ID": "parent-http-mcp", "Authorization": "Bearer token123"}`, // Headers - }, - } - if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, parentMCPServers); err != nil { - t.Fatalf("Failed to store MCP servers: %v", err) - } - - // Continue session - req := ContinueSessionConfig{ - ParentSessionID: parentSessionID, - Query: "continue http mcp", - } - - _, _ = manager.ContinueSession(ctx, req) - // Expected to fail due to missing Claude binary - - // Find the child session - sessions, err := sqliteStore.ListSessions(ctx) - if err != nil { - t.Fatalf("Failed to list sessions: %v", err) - } - - var childSession *store.Session - for _, s := range sessions { - if s.ParentSessionID == parentSessionID { - childSession = s - break - } - } - - if childSession == nil { - t.Fatal("Child session not found") - return - } - - // Get MCP servers for child session - childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID) - if err != nil { - t.Fatalf("Failed to get child MCP servers: %v", err) - } - - // Should have inherited the MCP server - if len(childMCPServers) != 1 { - t.Fatalf("Expected 1 MCP server, got %d", len(childMCPServers)) - } - - childMCPServer := childMCPServers[0] - - // Verify basic inheritance - if childMCPServer.Name != "http-test-server" { - t.Errorf("MCP server name not inherited: got %s, want http-test-server", childMCPServer.Name) - } - if childMCPServer.Command != "http" { - t.Errorf("MCP server type not inherited: got %s, want http", childMCPServer.Command) - } - - // Verify URL was inherited (stored in ArgsJSON) - var childArgs []string - if err := json.Unmarshal([]byte(childMCPServer.ArgsJSON), &childArgs); err != nil { - t.Fatalf("Failed to unmarshal child args: %v", err) - } - if len(childArgs) != 1 || childArgs[0] != "http://localhost:8080/mcp" { - t.Errorf("MCP server URL not inherited: got %v, want [http://localhost:8080/mcp]", childArgs) - } - - // Parse headers (stored in EnvJSON) and verify X-Session-ID was updated - var childHeaders map[string]string - if err := json.Unmarshal([]byte(childMCPServer.EnvJSON), &childHeaders); err != nil { - t.Fatalf("Failed to unmarshal child headers: %v", err) - } - - // CRITICAL: X-Session-ID should be the CHILD session ID, not the parent's - if xSessionID, ok := childHeaders["X-Session-ID"]; !ok { - t.Error("X-Session-ID header missing in child MCP server") - } else if xSessionID != childSession.ID { - t.Errorf("X-Session-ID not updated to child session ID: got %s, want %s", xSessionID, childSession.ID) - t.Log("This is the bug! X-Session-ID should be replaced with the child session ID") - } - - // Other headers should be preserved - if auth, ok := childHeaders["Authorization"]; !ok { - t.Error("Authorization header not inherited") - } else if auth != "Bearer token123" { - t.Errorf("Authorization header value changed: got %s, want Bearer token123", auth) - } - }) } diff --git a/hld/session/manager.go b/hld/session/manager.go index 556f698..a0119a1 100644 --- a/hld/session/manager.go +++ b/hld/session/manager.go @@ -69,42 +69,24 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig) claudeConfig := config.SessionConfig // Add HUMANLAYER_RUN_ID and HUMANLAYER_DAEMON_SOCKET to MCP server environment - // For HTTP servers, inject session ID header if claudeConfig.MCPConfig != nil { slog.Debug("configuring MCP servers", "count", len(claudeConfig.MCPConfig.MCPServers)) for name, server := range claudeConfig.MCPConfig.MCPServers { - // Check if this is an HTTP MCP server - if server.Type == "http" { - // For HTTP servers, inject session ID header if not already set - if server.Headers == nil { - server.Headers = make(map[string]string) - } - // Only inject if not already set (allow override) - if _, exists := server.Headers["X-Session-ID"]; !exists { - server.Headers["X-Session-ID"] = sessionID - } - slog.Debug("configured HTTP MCP server", - "name", name, - "url", server.URL, - "session_id", sessionID) - } else { - // For stdio servers, add environment variables - if server.Env == nil { - server.Env = make(map[string]string) - } - server.Env["HUMANLAYER_RUN_ID"] = runID - // Add daemon socket path so MCP servers connect to the correct daemon - if m.socketPath != "" { - server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath - } - slog.Debug("configured stdio MCP server", - "name", name, - "command", server.Command, - "args", server.Args, - "run_id", runID, - "socket_path", m.socketPath) + if server.Env == nil { + server.Env = make(map[string]string) + } + server.Env["HUMANLAYER_RUN_ID"] = runID + // Add daemon socket path so MCP servers connect to the correct daemon + if m.socketPath != "" { + server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath } claudeConfig.MCPConfig.MCPServers[name] = server + slog.Debug("configured MCP server", + "name", name, + "command", server.Command, + "args", server.Args, + "run_id", runID, + "socket_path", m.socketPath) } } else { slog.Debug("no MCP config provided") @@ -160,11 +142,7 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig) if claudeConfig.MCPConfig != nil { mcpServerCount = len(claudeConfig.MCPConfig.MCPServers) for name, server := range claudeConfig.MCPConfig.MCPServers { - if server.Type == "http" { - mcpServersDetail += fmt.Sprintf("[%s: type=http url=%s headers=%v] ", name, server.URL, server.Headers) - } else { - mcpServersDetail += fmt.Sprintf("[%s: cmd=%s args=%v env=%v] ", name, server.Command, server.Args, server.Env) - } + mcpServersDetail += fmt.Sprintf("[%s: cmd=%s args=%v env=%v] ", name, server.Command, server.Args, server.Env) } } slog.Info("launching Claude session with configuration", @@ -236,15 +214,8 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig) // Reconcile any existing approvals for this run_id if m.approvalReconciler != nil { go func() { - // Give the session a moment to start (with cancellation support) - select { - case <-time.After(2 * time.Second): - // Continue with reconciliation - case <-ctx.Done(): - // Context cancelled, exit early - return - } - + // Give the session a moment to start + time.Sleep(2 * time.Second) if err := m.approvalReconciler.ReconcileApprovalsForSession(ctx, runID); err != nil { slog.Error("failed to reconcile approvals for session", "session_id", sessionID, @@ -1212,24 +1183,10 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig env = map[string]string{} } - // Check if this is an HTTP server (stored with command="http") - if server.Command == "http" { - // HTTP server - extract URL from args and headers from env - var urls []string - if err := json.Unmarshal([]byte(server.ArgsJSON), &urls); err == nil && len(urls) > 0 { - config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{ - Type: "http", - URL: urls[0], - Headers: env, // Headers were stored in EnvJSON - } - } - } else { - // Traditional stdio server - config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{ - Command: server.Command, - Args: args, - Env: env, - } + config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{ + Command: server.Command, + Args: args, + Env: env, } } slog.Debug("inherited MCP servers from parent session", @@ -1298,28 +1255,15 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig } // Add run_id and daemon socket to MCP server environments - // For HTTP servers, inject session ID header - if config.MCPConfig != nil { for name, server := range config.MCPConfig.MCPServers { - // Check if this is an HTTP MCP server - if server.Type == "http" { - // For HTTP servers, always set session ID header to child session ID - if server.Headers == nil { - server.Headers = make(map[string]string) - } - // Always set X-Session-ID to the new child session ID (replaces inherited parent ID) - server.Headers["X-Session-ID"] = sessionID - } else { - // For stdio servers, add environment variables - if server.Env == nil { - server.Env = make(map[string]string) - } - server.Env["HUMANLAYER_RUN_ID"] = runID - // Add daemon socket path so MCP servers connect to the correct daemon - if m.socketPath != "" { - server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath - } + if server.Env == nil { + server.Env = make(map[string]string) + } + server.Env["HUMANLAYER_RUN_ID"] = runID + // Add daemon socket path so MCP servers connect to the correct daemon + if m.socketPath != "" { + server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath } config.MCPConfig.MCPServers[name] = server } @@ -1395,15 +1339,8 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig // Reconcile any existing approvals for this run_id (same run_id is reused for continuations) if m.approvalReconciler != nil { go func() { - // Give the session a moment to start (with cancellation support) - select { - case <-time.After(2 * time.Second): - // Continue with reconciliation - case <-ctx.Done(): - // Context cancelled, exit early - return - } - + // Give the session a moment to start + time.Sleep(2 * time.Second) if err := m.approvalReconciler.ReconcileApprovalsForSession(ctx, runID); err != nil { slog.Error("failed to reconcile approvals for continued session", "session_id", sessionID, diff --git a/hld/store/migration_test.go b/hld/store/migration_test.go deleted file mode 100644 index 8d37cba..0000000 --- a/hld/store/migration_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package store_test - -import ( - "context" - "testing" - - "github.com/humanlayer/humanlayer/hld/store" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMigration14_ToolUseID(t *testing.T) { - // Create an in-memory database for testing - s, err := store.NewSQLiteStore(":memory:") - require.NoError(t, err) - defer func() { _ = s.Close() }() - - // First create a session to satisfy foreign key constraint - session := &store.Session{ - ID: "test-session-1", - RunID: "test-run-1", - Query: "test query", - Status: store.SessionStatusRunning, - } - err = s.CreateSession(context.Background(), session) - require.NoError(t, err, "Should be able to create session") - - // Create a test approval with tool_use_id - toolUseID := "test-tool-use-id-123" - approval := &store.Approval{ - ID: "test-approval-1", - RunID: "test-run-1", - SessionID: "test-session-1", - ToolUseID: &toolUseID, - Status: store.ApprovalStatusLocalPending, - ToolName: "test-tool", - ToolInput: []byte(`{"test": "data"}`), - } - - // Create the approval - err = s.CreateApproval(context.Background(), approval) - require.NoError(t, err, "Should be able to create approval with tool_use_id") - - // Retrieve the approval - retrieved, err := s.GetApproval(context.Background(), "test-approval-1") - require.NoError(t, err) - require.NotNil(t, retrieved) - - // Verify tool_use_id was saved and retrieved correctly - assert.NotNil(t, retrieved.ToolUseID, "ToolUseID should not be nil") - if retrieved.ToolUseID != nil { - assert.Equal(t, toolUseID, *retrieved.ToolUseID, "ToolUseID should match") - } - - // Create another session for the second approval - session2 := &store.Session{ - ID: "test-session-2", - RunID: "test-run-2", - Query: "test query 2", - Status: store.SessionStatusRunning, - } - err = s.CreateSession(context.Background(), session2) - require.NoError(t, err, "Should be able to create second session") - - // Test creating approval without tool_use_id (nullable field) - approval2 := &store.Approval{ - ID: "test-approval-2", - RunID: "test-run-2", - SessionID: "test-session-2", - ToolUseID: nil, // Explicitly nil - Status: store.ApprovalStatusLocalPending, - ToolName: "test-tool-2", - ToolInput: []byte(`{"test": "data2"}`), - } - - err = s.CreateApproval(context.Background(), approval2) - require.NoError(t, err, "Should be able to create approval without tool_use_id") - - // Retrieve and verify it's nil - retrieved2, err := s.GetApproval(context.Background(), "test-approval-2") - require.NoError(t, err) - assert.Nil(t, retrieved2.ToolUseID, "ToolUseID should be nil when not provided") -} diff --git a/hld/store/sqlite.go b/hld/store/sqlite.go index 1eb37d9..151bb0f 100644 --- a/hld/store/sqlite.go +++ b/hld/store/sqlite.go @@ -695,75 +695,6 @@ func (s *SQLiteStore) applyMigrations() error { slog.Info("Migration 13 applied successfully") } - // Migration 14: Add tool_use_id column to approvals table - if currentVersion < 14 { - slog.Info("Applying migration 14: Add tool_use_id column to approvals table") - - // Check if column already exists for idempotency - var columnExists int - err = s.db.QueryRow(` - SELECT COUNT(*) FROM pragma_table_info('approvals') - WHERE name = 'tool_use_id' - `).Scan(&columnExists) - if err != nil { - return fmt.Errorf("failed to check tool_use_id column: %w", err) - } - - // Only add column if it doesn't exist - if columnExists == 0 { - _, err = s.db.Exec(` - ALTER TABLE approvals - ADD COLUMN tool_use_id TEXT - `) - if err != nil { - return fmt.Errorf("failed to add tool_use_id column: %w", err) - } - } - - // Create index for efficient lookups - _, err = s.db.Exec(` - CREATE INDEX IF NOT EXISTS idx_approvals_tool_use_id - ON approvals(tool_use_id) - WHERE tool_use_id IS NOT NULL - `) - if err != nil { - return fmt.Errorf("failed to create tool_use_id index: %w", err) - } - - // Update existing approvals to populate tool_use_id from correlated events - _, err = s.db.Exec(` - UPDATE approvals - SET tool_use_id = ( - SELECT ce.tool_id - FROM conversation_events ce - WHERE ce.approval_id = approvals.id - AND ce.tool_id IS NOT NULL - LIMIT 1 - ) - WHERE EXISTS ( - SELECT 1 - FROM conversation_events ce - WHERE ce.approval_id = approvals.id - AND ce.tool_id IS NOT NULL - ) - `) - if err != nil { - slog.Warn("Failed to populate tool_use_id for existing approvals (non-critical)", "error", err) - // This is non-critical as existing approvals may not have correlation - } - - // Record migration - _, err = s.db.Exec(` - INSERT INTO schema_version (version, description) - VALUES (14, 'Add tool_use_id column to approvals table for direct correlation') - `) - if err != nil { - return fmt.Errorf("failed to record migration 14: %w", err) - } - - slog.Info("Migration 14 applied successfully") - } - return nil } @@ -1825,8 +1756,8 @@ func (s *SQLiteStore) CorrelateApproval(ctx context.Context, sessionID string, t return nil } -// LinkConversationEventToApprovalUsingToolID correlates an approval with a specific tool call by tool_id -func (s *SQLiteStore) LinkConversationEventToApprovalUsingToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error { +// CorrelateApprovalByToolID correlates an approval with a specific tool call by tool_id +func (s *SQLiteStore) CorrelateApprovalByToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error { // Update the tool call directly by tool_id updateQuery := ` UPDATE conversation_events @@ -1969,14 +1900,14 @@ func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) er query := ` INSERT INTO approvals ( - id, run_id, session_id, tool_use_id, status, created_at, - tool_name, tool_input, comment - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + id, run_id, session_id, status, created_at, + tool_name, tool_input + ) VALUES (?, ?, ?, ?, ?, ?, ?) ` _, err := s.db.ExecContext(ctx, query, - approval.ID, approval.RunID, approval.SessionID, approval.ToolUseID, approval.Status.String(), approval.CreatedAt, - approval.ToolName, string(approval.ToolInput), approval.Comment, + approval.ID, approval.RunID, approval.SessionID, approval.Status.String(), approval.CreatedAt, + approval.ToolName, string(approval.ToolInput), ) if err != nil { return fmt.Errorf("failed to create approval: %w", err) @@ -1987,20 +1918,19 @@ func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) er // GetApproval retrieves an approval by ID func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, error) { query := ` - SELECT id, run_id, session_id, tool_use_id, status, created_at, responded_at, + SELECT id, run_id, session_id, status, created_at, responded_at, tool_name, tool_input, comment FROM approvals WHERE id = ? ` var approval Approval - var toolUseID sql.NullString var respondedAt sql.NullTime var comment sql.NullString var statusStr string var toolInputStr string err := s.db.QueryRowContext(ctx, query, id).Scan( - &approval.ID, &approval.RunID, &approval.SessionID, &toolUseID, &statusStr, + &approval.ID, &approval.RunID, &approval.SessionID, &statusStr, &approval.CreatedAt, &respondedAt, &approval.ToolName, &toolInputStr, &comment, ) @@ -2018,9 +1948,6 @@ func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, er } // Handle nullable fields - if toolUseID.Valid { - approval.ToolUseID = &toolUseID.String - } if respondedAt.Valid { approval.RespondedAt = &respondedAt.Time } @@ -2033,7 +1960,7 @@ func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, er // GetPendingApprovals retrieves all pending approvals for a session func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error) { query := ` - SELECT id, run_id, session_id, tool_use_id, status, created_at, responded_at, + SELECT id, run_id, session_id, status, created_at, responded_at, tool_name, tool_input, comment FROM approvals WHERE session_id = ? AND status = ? @@ -2049,14 +1976,13 @@ func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) var approvals []*Approval for rows.Next() { var approval Approval - var toolUseID sql.NullString var respondedAt sql.NullTime var comment sql.NullString var statusStr string var toolInputStr string err := rows.Scan( - &approval.ID, &approval.RunID, &approval.SessionID, &toolUseID, &statusStr, + &approval.ID, &approval.RunID, &approval.SessionID, &statusStr, &approval.CreatedAt, &respondedAt, &approval.ToolName, &toolInputStr, &comment, ) @@ -2071,9 +1997,6 @@ func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) } // Handle nullable fields - if toolUseID.Valid { - approval.ToolUseID = &toolUseID.String - } if respondedAt.Valid { approval.RespondedAt = &respondedAt.Time } @@ -2141,47 +2064,22 @@ func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServ servers := make([]MCPServer, 0, len(config)) for _, name := range names { server := config[name] + argsJSON, err := json.Marshal(server.Args) + if err != nil { + return nil, fmt.Errorf("failed to marshal args: %w", err) + } - // For HTTP servers, store the configuration differently - // We'll use Command field to store the type, ArgsJSON for URL, and EnvJSON for headers - var command string - var argsJSON string - var envJSON string - - if server.Type == "http" { - // HTTP server - command = "http" // Use "http" as the command to indicate HTTP type - argsJSON = fmt.Sprintf(`["%s"]`, server.URL) // Store URL as single-element array - - // Store headers in EnvJSON - headersData, err := json.Marshal(server.Headers) - if err != nil { - return nil, fmt.Errorf("failed to marshal headers: %w", err) - } - envJSON = string(headersData) - } else { - // Traditional stdio server - command = server.Command - - argsData, err := json.Marshal(server.Args) - if err != nil { - return nil, fmt.Errorf("failed to marshal args: %w", err) - } - argsJSON = string(argsData) - - envData, err := json.Marshal(server.Env) - if err != nil { - return nil, fmt.Errorf("failed to marshal env: %w", err) - } - envJSON = string(envData) + envJSON, err := json.Marshal(server.Env) + if err != nil { + return nil, fmt.Errorf("failed to marshal env: %w", err) } servers = append(servers, MCPServer{ SessionID: sessionID, Name: name, - Command: command, - ArgsJSON: argsJSON, - EnvJSON: envJSON, + Command: server.Command, + ArgsJSON: string(argsJSON), + EnvJSON: string(envJSON), }) } return servers, nil diff --git a/hld/store/store.go b/hld/store/store.go index fe32f8b..90a0dbe 100644 --- a/hld/store/store.go +++ b/hld/store/store.go @@ -31,7 +31,7 @@ type ConversationStore interface { GetToolCallByID(ctx context.Context, toolID string) (*ConversationEvent, error) MarkToolCallCompleted(ctx context.Context, toolID string, sessionID string) error CorrelateApproval(ctx context.Context, sessionID string, toolName string, approvalID string) error - LinkConversationEventToApprovalUsingToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error + CorrelateApprovalByToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error UpdateApprovalStatus(ctx context.Context, approvalID string, status string) error // MCP server operations @@ -201,7 +201,6 @@ type Approval struct { ID string `json:"id"` RunID string `json:"run_id"` SessionID string `json:"session_id"` - ToolUseID *string `json:"tool_use_id,omitempty"` Status ApprovalStatus `json:"status"` CreatedAt time.Time `json:"created_at"` RespondedAt *time.Time `json:"responded_at,omitempty"` diff --git a/hlyr/src/commands/launch.ts b/hlyr/src/commands/launch.ts index 73ed7e6..7640c6c 100644 --- a/hlyr/src/commands/launch.ts +++ b/hlyr/src/commands/launch.ts @@ -46,16 +46,13 @@ export const launchCommand = async (query: string, options: LaunchOptions = {}) try { // Build MCP config (approvals enabled by default unless explicitly disabled) - // Phase 6: Using HTTP MCP endpoint instead of stdio - const daemonPort = process.env.HUMANLAYER_DAEMON_HTTP_PORT || '7777' const mcpConfig = options.approvals !== false ? { mcpServers: { - codelayer: { - type: 'http', - url: `http://localhost:${daemonPort}/api/v1/mcp`, - // Session ID will be added as header by Claude Code + approvals: { + command: 'npx', + args: ['humanlayer', 'mcp', 'claude_approvals'], }, }, } @@ -69,7 +66,7 @@ export const launchCommand = async (query: string, options: LaunchOptions = {}) working_dir: options.workingDir || process.cwd(), max_turns: options.maxTurns, mcp_config: mcpConfig, - permission_prompt_tool: mcpConfig ? 'mcp__codelayer__request_approval' : undefined, + permission_prompt_tool: mcpConfig ? 'mcp__approvals__request_permission' : undefined, dangerously_skip_permissions: options.dangerouslySkipPermissions, dangerously_skip_permissions_timeout: options.dangerouslySkipPermissionsTimeout ? parseInt(options.dangerouslySkipPermissionsTimeout) * 60 * 1000 diff --git a/hlyr/src/index.ts b/hlyr/src/index.ts index 0aea40c..294cb26 100644 --- a/hlyr/src/index.ts +++ b/hlyr/src/index.ts @@ -10,6 +10,7 @@ import { launchCommand } from './commands/launch.js' import { alertCommand } from './commands/alert.js' import { thoughtsCommand } from './commands/thoughts.js' import { joinWaitlistCommand } from './commands/joinWaitlist.js' +import { startDefaultMCPServer, startClaudeApprovalsMCPServer } from './mcp.js' import { getDefaultConfigPath, resolveFullConfig, @@ -66,7 +67,7 @@ async function authenticate(printSelectedProject: boolean = false) { program.name('humanlayer').description('HumanLayer, but on your command-line.').version(VERSION) -const UNPROTECTED_COMMANDS = ['config', 'login', 'thoughts', 'join-waitlist', 'launch'] +const UNPROTECTED_COMMANDS = ['config', 'login', 'thoughts', 'join-waitlist', 'launch', 'mcp'] program.hook('preAction', async (thisCmd, actionCmd) => { // Get the full command path by traversing up the command hierarchy @@ -171,6 +172,36 @@ program .option('--daemon-socket ', 'Path to daemon socket') .action(alertCommand) +const mcpCommand = program.command('mcp').description('MCP server functionality') + +mcpCommand + .command('serve') + .description('Start the default MCP server for contact_human functionality') + .action(startDefaultMCPServer) + +mcpCommand + .command('claude_approvals') + .description('Start the Claude approvals MCP server for permission requests') + .action(startClaudeApprovalsMCPServer) + +mcpCommand + .command('wrapper') + .description('Wrap an existing MCP server with human approval functionality (not implemented yet)') + .action(() => { + console.log('MCP wrapper functionality is not implemented yet.') + console.log('This will allow wrapping any existing MCP server with human approval.') + process.exit(1) + }) + +mcpCommand + .command('inspector') + .description('Run MCP inspector for debugging MCP servers') + .argument('[command]', 'MCP server command to inspect', 'serve') + .action(command => { + const args = ['@modelcontextprotocol/inspector', 'node', 'dist/index.js', 'mcp', command] + spawn('npx', args, { stdio: 'inherit', cwd: process.cwd() }) + }) + // Add thoughts command thoughtsCommand(program) diff --git a/hlyr/src/mcp.ts b/hlyr/src/mcp.ts new file mode 100644 index 0000000..52f8ff6 --- /dev/null +++ b/hlyr/src/mcp.ts @@ -0,0 +1,274 @@ +import { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { + CallToolRequestSchema, + ErrorCode, + ListToolsRequestSchema, + McpError, +} from '@modelcontextprotocol/sdk/types.js' +import { humanlayer } from '@humanlayer/sdk' +import { resolveFullConfig } from './config.js' +import { DaemonClient } from './daemonClient.js' +import { logger } from './mcpLogger.js' + +function validateAuth(): void { + const config = resolveFullConfig({}) + + if (!config.api_key) { + console.error('Error: No HumanLayer API token found.') + console.error('Please set HUMANLAYER_API_KEY environment variable or run `humanlayer login`') + process.exit(1) + } +} + +/** + * Start the default MCP server that provides contact_human functionality + * Uses web UI by default when no contact channel is configured + */ +export async function startDefaultMCPServer() { + validateAuth() + + const server = new Server( + { + name: 'humanlayer-standalone', + version: '1.0.0', + }, + { + capabilities: { + tools: {}, + }, + }, + ) + + const resolvedConfig = resolveFullConfig({}) + + const hl = humanlayer({ + apiKey: resolvedConfig.api_key, + ...(resolvedConfig.api_base_url && { apiBaseUrl: resolvedConfig.api_base_url }), + ...(resolvedConfig.run_id && { runId: resolvedConfig.run_id }), + ...(Object.keys(resolvedConfig.contact_channel).length > 0 && { + contactChannel: resolvedConfig.contact_channel, + }), + }) + + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + { + name: 'contact_human', + description: 'Contact a human for assistance', + inputSchema: { + type: 'object', + properties: { + message: { type: 'string' }, + }, + required: ['message'], + }, + }, + ], + } + }) + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'contact_human') { + const response = await hl.fetchHumanResponse({ + spec: { + msg: request.params.arguments?.message, + }, + }) + + return { + content: [ + { + type: 'text', + text: response, + }, + ], + } + } + + throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name') + }) + + const transport = new StdioServerTransport() + await server.connect(transport) +} + +/** + * Start the Claude approvals MCP server that provides request_permission functionality + * Returns responses in the format required by Claude Code SDK + * + * This now uses local approvals through the daemon instead of HumanLayer API + */ +export async function startClaudeApprovalsMCPServer() { + // No auth validation needed - uses local daemon + logger.info('Starting Claude approvals MCP server') + logger.info('Environment variables', { + HUMANLAYER_DAEMON_SOCKET: process.env.HUMANLAYER_DAEMON_SOCKET, + HUMANLAYER_RUN_ID: process.env.HUMANLAYER_RUN_ID, + }) + + const server = new Server( + { + name: 'humanlayer-claude-local-approvals', + version: '1.0.0', + }, + { + capabilities: { + tools: {}, + }, + }, + ) + + // Create daemon client with socket path from environment or config + // The daemon sets HUMANLAYER_DAEMON_SOCKET for MCP servers it launches + const resolvedConfig = resolveFullConfig({}) + const socketPath = process.env.HUMANLAYER_DAEMON_SOCKET || resolvedConfig.daemon_socket + logger.info('Creating daemon client', { socketPath }) + const daemonClient = new DaemonClient(socketPath) + + server.setRequestHandler(ListToolsRequestSchema, async () => { + logger.info('ListTools request received') + const tools = [ + { + name: 'request_permission', + description: 'Request permission to perform an action', + inputSchema: { + type: 'object', + properties: { + tool_name: { type: 'string' }, + input: { type: 'object' }, + }, + required: ['tool_name', 'input'], + }, + }, + ] + logger.info('Returning tools', { tools }) + return { tools } + }) + + server.setRequestHandler(CallToolRequestSchema, async request => { + logger.debug('Received tool call request', { name: request.params.name }) + + if (request.params.name === 'request_permission') { + const toolName: string | undefined = request.params.arguments?.tool_name + + if (!toolName) { + logger.error('Invalid tool name in request_permission', request.params.arguments) + throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name requesting permissions') + } + + const input: Record = request.params.arguments?.input || {} + + // Get run ID from environment (set by Claude Code) + const runId = process.env.HUMANLAYER_RUN_ID + if (!runId) { + logger.error('HUMANLAYER_RUN_ID not set in environment') + throw new McpError(ErrorCode.InternalError, 'HUMANLAYER_RUN_ID not set') + } + + logger.info('Processing approval request', { runId, toolName }) + + try { + // Connect to daemon + logger.debug('Connecting to daemon...') + await daemonClient.connect() + logger.debug('Connected to daemon') + + // Create approval request + logger.debug('Creating approval request...', { runId, toolName }) + const createResponse = await daemonClient.createApproval(runId, toolName, input) + const approvalId = createResponse.approval_id + logger.info('Created approval', { approvalId }) + + // Poll for approval status + let approved = false + let comment = '' + let polling = true + + while (polling) { + try { + // Get the specific approval by ID + logger.debug('Fetching approval status...', { approvalId }) + const approval = (await daemonClient.getApproval(approvalId)) as { + id: string + status: string + comment?: string + } + + logger.debug('Approval status', { status: approval.status }) + + if (approval.status !== 'pending') { + // Approval has been resolved + approved = approval.status === 'approved' + comment = approval.comment || '' + polling = false + logger.info('Approval resolved', { + approvalId, + status: approval.status, + approved, + }) + } else { + // Still pending, wait and poll again + logger.debug('Approval still pending, polling again...') + await new Promise(resolve => setTimeout(resolve, 1000)) + } + } catch (error) { + logger.error('Failed to get approval status', { error, approvalId }) + // Re-throw the error since this is a critical failure + throw new McpError(ErrorCode.InternalError, 'Failed to get approval status') + } + } + + if (!approved) { + logger.info('Approval denied', { approvalId, comment }) + return { + content: [ + { + type: 'text', + text: JSON.stringify({ + behavior: 'deny', + message: comment || 'Request denied by human reviewer', + }), + }, + ], + } + } + + logger.info('Approval granted', { approvalId }) + return { + content: [ + { + type: 'text', + text: JSON.stringify({ + behavior: 'allow', + updatedInput: input, + }), + }, + ], + } + } catch (error) { + logger.error('Failed to process approval', error) + throw new McpError( + ErrorCode.InternalError, + `Failed to process approval: ${error instanceof Error ? error.message : String(error)}`, + ) + } finally { + logger.debug('Closing daemon connection') + daemonClient.close() + } + } + + throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name') + }) + + const transport = new StdioServerTransport() + + try { + await server.connect(transport) + logger.info('MCP server connected and ready') + } catch (error) { + logger.error('Failed to start MCP server', error) + throw error + } +} diff --git a/humanlayer-wui/src/hooks/useApprovals.ts b/humanlayer-wui/src/hooks/useApprovals.ts index 5f79d62..bbdeb6a 100644 --- a/humanlayer-wui/src/hooks/useApprovals.ts +++ b/humanlayer-wui/src/hooks/useApprovals.ts @@ -95,16 +95,6 @@ export function useApprovalsWithSubscription(sessionId?: string): UseApprovalsRe onEvent: event => { if (!isSubscribed) return - // Phase 7: Debug logging to verify tool_use_id flows through - if (event.type === 'new_approval' || event.type === 'approval_resolved') { - console.debug('Approval event with tool_use_id:', { - type: event.type, - approval_id: event.data?.approval_id, - tool_use_id: event.data?.tool_use_id, - data: event.data, - }) - } - // Handle different event types switch (event.type) { case 'new_approval': diff --git a/humanlayer-wui/src/hooks/useSessionLauncher.ts b/humanlayer-wui/src/hooks/useSessionLauncher.ts index f980139..77abeac 100644 --- a/humanlayer-wui/src/hooks/useSessionLauncher.ts +++ b/humanlayer-wui/src/hooks/useSessionLauncher.ts @@ -1,7 +1,6 @@ import { create } from 'zustand' import { daemonClient } from '@/lib/daemon' import type { LaunchSessionRequest } from '@/lib/daemon/types' -import { getDaemonUrl } from '@/lib/daemon/http-config' import { useHotkeysContext } from 'react-hotkeys-hook' import { SessionTableHotkeysScope } from '@/components/internal/SessionTable' import { exists } from '@tauri-apps/plugin-fs' @@ -142,13 +141,11 @@ export const useSessionLauncher = create((set, get) => ({ set({ isLaunching: true, error: undefined }) // Build MCP config (approvals enabled by default) - // Use HTTP-based MCP server built into the daemon - const daemonUrl = await getDaemonUrl() const mcpConfig = { mcpServers: { approvals: { - type: 'http', - url: `${daemonUrl}/api/v1/mcp`, + command: 'npx', + args: ['humanlayer', 'mcp', 'claude_approvals'], }, }, } @@ -160,7 +157,7 @@ export const useSessionLauncher = create((set, get) => ({ model: config.model || undefined, max_turns: config.maxTurns || undefined, mcp_config: mcpConfig, - permission_prompt_tool: 'mcp__approvals__request_approval', + permission_prompt_tool: 'mcp__approvals__request_permission', } const response = await daemonClient.launchSession(request)