mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
Revert "sHTTP Conversion FTW"
This commit is contained in:
@@ -26,12 +26,7 @@ When invoked with a parameter like `gh_username:branchName`:
|
||||
- Run setup: `make -C WORKTREE setup`
|
||||
- Initialize thoughts: `cd WORKTREE && npx humanlayer thoughts init --directory humanlayer`
|
||||
|
||||
5. **Build CodeLayer in background**:
|
||||
- Run: `make -C WORKTREE codelayer-dev &`
|
||||
- This builds CodeLayer in the background so it's ready when needed
|
||||
- If port 1420 is already in use, run with a different port: `VITE_PORT=1421 make -C WORKTREE codelayer-dev &`
|
||||
|
||||
6. **Launch Claude Code session**:
|
||||
5. **Launch Claude Code session**:
|
||||
- If a ticket number was found: `npx humanlayer launch --model opus -w WORKTREE "We're working on ENG-XXXX - fetch the issue and then await further instructions"`
|
||||
- Otherwise: `npx humanlayer launch --model opus -w WORKTREE "We're reviewing the branch BRANCHNAME - please familiarize yourself with the changes and await further instructions"`
|
||||
|
||||
@@ -40,7 +35,6 @@ When invoked with a parameter like `gh_username:branchName`:
|
||||
- If worktree already exists, inform the user they need to remove it first
|
||||
- If remote fetch fails, check if the username/repo exists
|
||||
- If setup fails, provide the error but continue with the launch
|
||||
- If CodeLayer fails with "Port 1420 is already in use", use the VITE_PORT alternative command
|
||||
|
||||
## Example Usage
|
||||
|
||||
@@ -52,5 +46,4 @@ This will:
|
||||
- Add 'samdickson22' as a remote
|
||||
- Create worktree at `~/wt/humanlayer/eng-1696`
|
||||
- Set up the environment
|
||||
- Build CodeLayer in the background
|
||||
- Launch Claude with: "We're working on ENG-1696 - fetch the issue and then await further instructions"
|
||||
|
||||
@@ -27,17 +27,10 @@ const (
|
||||
)
|
||||
|
||||
// MCPServer represents a single MCP server configuration
|
||||
// It can be either a stdio-based server (with command/args/env) or an HTTP server (with type/url/headers)
|
||||
type MCPServer struct {
|
||||
// For stdio-based servers
|
||||
Command string `json:"command,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
|
||||
// For HTTP servers
|
||||
Type string `json:"type,omitempty"` // "http" for HTTP servers
|
||||
URL string `json:"url,omitempty"` // The HTTP endpoint URL
|
||||
Headers map[string]string `json:"headers,omitempty"` // HTTP headers to include
|
||||
}
|
||||
|
||||
// MCPConfig represents the MCP configuration structure
|
||||
|
||||
@@ -6,7 +6,7 @@ The daemon logs are in ~/.humanlayer/logs/daemon-*.log (timestamped files create
|
||||
|
||||
WUI logs (which include daemon stderr output) are in:
|
||||
- Development: `~/.humanlayer/logs/wui-{branch}/codelayer.log`
|
||||
- Production: Platform-specific log directories, e.g. ~/Library/Logs/dev.humanlayer.wui.nightly/CodeLayer-Nightly.log
|
||||
- Production: Platform-specific log directories
|
||||
|
||||
It uses a database at ~/.humanlayer/*.db - you can access it with sqlite3 to inspect progress and debug things.
|
||||
|
||||
@@ -47,9 +47,3 @@ echo '{"jsonrpc":"2.0","method":"getSessionLeaves","params":{},"id":1}' | nc -U
|
||||
|
||||
|
||||
For testing guidelines and database isolation requirements, see TESTING.md
|
||||
|
||||
|
||||
### Go style guidelines
|
||||
|
||||
- any async or long-running goroutine should accept a context.Context as a parameter and handle cancellation gracefully
|
||||
- context and CancelFuncs should never be stored on structs, always passed as the first parameter to a function
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestSessionHandlers_CreateSession(t *testing.T) {
|
||||
McpConfig: &api.MCPConfig{
|
||||
McpServers: &map[string]api.MCPServer{
|
||||
"test-server": {
|
||||
Command: stringPtr("node"),
|
||||
Command: "node",
|
||||
Args: &[]string{"server.js"},
|
||||
Env: &map[string]string{
|
||||
"DEBUG": "true",
|
||||
|
||||
@@ -201,22 +201,8 @@ func (m *Mapper) MCPConfigFromAPI(config *api.MCPConfig) *claudecode.MCPConfig {
|
||||
servers := make(map[string]claudecode.MCPServer)
|
||||
if config.McpServers != nil {
|
||||
for name, server := range *config.McpServers {
|
||||
mcpServer := claudecode.MCPServer{}
|
||||
|
||||
// Map HTTP server fields
|
||||
if server.Type != nil {
|
||||
mcpServer.Type = *server.Type
|
||||
}
|
||||
if server.Url != nil {
|
||||
mcpServer.URL = *server.Url
|
||||
}
|
||||
if server.Headers != nil {
|
||||
mcpServer.Headers = *server.Headers
|
||||
}
|
||||
|
||||
// Map stdio server fields
|
||||
if server.Command != nil {
|
||||
mcpServer.Command = *server.Command
|
||||
mcpServer := claudecode.MCPServer{
|
||||
Command: server.Command,
|
||||
}
|
||||
if server.Args != nil {
|
||||
mcpServer.Args = *server.Args
|
||||
@@ -224,7 +210,6 @@ func (m *Mapper) MCPConfigFromAPI(config *api.MCPConfig) *claudecode.MCPConfig {
|
||||
if server.Env != nil {
|
||||
mcpServer.Env = *server.Env
|
||||
}
|
||||
|
||||
servers[name] = mcpServer
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1242,39 +1242,26 @@ components:
|
||||
|
||||
MCPServer:
|
||||
type: object
|
||||
required:
|
||||
- command
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: Server type (http for HTTP servers, omit for stdio)
|
||||
example: http
|
||||
command:
|
||||
type: string
|
||||
description: Command to execute (for stdio servers)
|
||||
description: Command to execute
|
||||
example: mcp-server-filesystem
|
||||
args:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: Command arguments (for stdio servers)
|
||||
description: Command arguments
|
||||
example: ["--read-only", "/home/user"]
|
||||
env:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Environment variables (for stdio servers)
|
||||
description: Environment variables
|
||||
example:
|
||||
DEBUG: "true"
|
||||
url:
|
||||
type: string
|
||||
description: HTTP endpoint URL (for HTTP servers)
|
||||
example: http://localhost:7777/api/v1/mcp
|
||||
headers:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: HTTP headers to include (for HTTP servers)
|
||||
example:
|
||||
X-Session-ID: "session-123"
|
||||
|
||||
# Event Types
|
||||
EventType:
|
||||
|
||||
@@ -447,23 +447,14 @@ type MCPConfig struct {
|
||||
|
||||
// MCPServer defines model for MCPServer.
|
||||
type MCPServer struct {
|
||||
// Args Command arguments (for stdio servers)
|
||||
// Args Command arguments
|
||||
Args *[]string `json:"args,omitempty"`
|
||||
|
||||
// Command Command to execute (for stdio servers)
|
||||
Command *string `json:"command,omitempty"`
|
||||
// Command Command to execute
|
||||
Command string `json:"command"`
|
||||
|
||||
// Env Environment variables (for stdio servers)
|
||||
// Env Environment variables
|
||||
Env *map[string]string `json:"env,omitempty"`
|
||||
|
||||
// Headers HTTP headers to include (for HTTP servers)
|
||||
Headers *map[string]string `json:"headers,omitempty"`
|
||||
|
||||
// Type Server type (http for HTTP servers, omit for stdio)
|
||||
Type *string `json:"type,omitempty"`
|
||||
|
||||
// Url HTTP endpoint URL (for HTTP servers)
|
||||
Url *string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// RecentPath defines model for RecentPath.
|
||||
|
||||
@@ -232,7 +232,7 @@ func (m *manager) correlateApproval(ctx context.Context, approval *store.Approva
|
||||
}
|
||||
|
||||
// Correlate by tool ID
|
||||
if err := m.store.LinkConversationEventToApprovalUsingToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil {
|
||||
if err := m.store.CorrelateApprovalByToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil {
|
||||
return fmt.Errorf("failed to correlate approval: %w", err)
|
||||
}
|
||||
|
||||
@@ -258,20 +258,15 @@ func (m *manager) publishNewApprovalEvent(approval *store.Approval) {
|
||||
// publishApprovalResolvedEvent publishes an event when an approval is resolved
|
||||
func (m *manager) publishApprovalResolvedEvent(approval *store.Approval, approved bool, responseText string) {
|
||||
if m.eventBus != nil {
|
||||
eventData := map[string]interface{}{
|
||||
"approval_id": approval.ID,
|
||||
"session_id": approval.SessionID,
|
||||
"approved": approved,
|
||||
"response_text": responseText,
|
||||
}
|
||||
// Include tool_use_id if present
|
||||
if approval.ToolUseID != nil {
|
||||
eventData["tool_use_id"] = *approval.ToolUseID
|
||||
}
|
||||
event := bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Timestamp: time.Now(),
|
||||
Data: eventData,
|
||||
Data: map[string]interface{}{
|
||||
"approval_id": approval.ID,
|
||||
"session_id": approval.SessionID,
|
||||
"approved": approved,
|
||||
"response_text": responseText,
|
||||
},
|
||||
}
|
||||
m.eventBus.Publish(event)
|
||||
}
|
||||
@@ -286,109 +281,6 @@ func (m *manager) updateSessionStatus(ctx context.Context, sessionID, status str
|
||||
return m.store.UpdateSession(ctx, sessionID, updates)
|
||||
}
|
||||
|
||||
// CreateApprovalWithToolUseID creates an approval with tool_use_id field
|
||||
func (m *manager) CreateApprovalWithToolUseID(ctx context.Context, sessionID, toolName string, toolInput json.RawMessage, toolUseID string) (*store.Approval, error) {
|
||||
// Check if auto-accept is enabled (either mode)
|
||||
session, err := m.store.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("session not found: %s", sessionID)
|
||||
}
|
||||
|
||||
status := store.ApprovalStatusLocalPending
|
||||
comment := ""
|
||||
|
||||
// Check dangerously skip permissions first (overrides edit mode)
|
||||
if session.DangerouslySkipPermissions {
|
||||
// Check if it has an expiry and if it's expired
|
||||
if session.DangerouslySkipPermissionsExpiresAt != nil && time.Now().After(*session.DangerouslySkipPermissionsExpiresAt) {
|
||||
// Expired - disable it
|
||||
update := store.SessionUpdate{
|
||||
DangerouslySkipPermissions: &[]bool{false}[0],
|
||||
DangerouslySkipPermissionsExpiresAt: &[]*time.Time{nil}[0],
|
||||
}
|
||||
if err := m.store.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to disable expired dangerously skip permissions", "session_id", session.ID, "error", err)
|
||||
}
|
||||
// Continue with normal approval
|
||||
} else {
|
||||
// Dangerously skip permissions is active (no expiry or not expired)
|
||||
status = store.ApprovalStatusLocalApproved
|
||||
comment = "Auto-accepted (dangerous skip permissions enabled)"
|
||||
}
|
||||
} else if session.AutoAcceptEdits && isEditTool(toolName) {
|
||||
// Regular auto-accept edits mode
|
||||
status = store.ApprovalStatusLocalApproved
|
||||
comment = "Auto-accepted (auto-accept mode enabled)"
|
||||
}
|
||||
|
||||
// Create approval with tool_use_id
|
||||
approval := &store.Approval{
|
||||
ID: "local-" + uuid.New().String(),
|
||||
RunID: session.RunID,
|
||||
SessionID: sessionID,
|
||||
ToolUseID: &toolUseID,
|
||||
Status: status,
|
||||
CreatedAt: time.Now(),
|
||||
ToolName: toolName,
|
||||
ToolInput: toolInput,
|
||||
Comment: comment,
|
||||
}
|
||||
|
||||
// Store it
|
||||
if err := m.store.CreateApproval(ctx, approval); err != nil {
|
||||
return nil, fmt.Errorf("failed to store approval: %w", err)
|
||||
}
|
||||
|
||||
// Publish event for real-time updates
|
||||
m.publishNewApprovalEvent(approval)
|
||||
|
||||
if err := m.store.LinkConversationEventToApprovalUsingToolID(ctx, sessionID, toolUseID, approval.ID); err != nil {
|
||||
// Log but don't fail
|
||||
// TODO(1): Don't ship if above LinkConversationEventToApprovalUsingToolID does not retry
|
||||
// it's possible, albeit unlikely, that the raw_event has not made it to
|
||||
// conversation_events yet
|
||||
return nil, fmt.Errorf("failed to correlate approval: %w", err)
|
||||
}
|
||||
|
||||
// Handle status-specific post-creation tasks
|
||||
switch status {
|
||||
case store.ApprovalStatusLocalPending:
|
||||
// Update session status to waiting_input for pending approvals
|
||||
if err := m.updateSessionStatus(ctx, session.ID, store.SessionStatusWaitingInput); err != nil {
|
||||
slog.Warn("failed to update session status",
|
||||
"error", err,
|
||||
"session_id", session.ID)
|
||||
}
|
||||
case store.ApprovalStatusLocalApproved:
|
||||
// For auto-approved, update correlation status immediately
|
||||
// Update approval status
|
||||
if err := m.store.UpdateApprovalStatus(ctx, approval.ID, store.ApprovalStatusApproved); err != nil {
|
||||
slog.Warn("failed to update approval status in conversation events",
|
||||
"error", err,
|
||||
"approval_id", approval.ID)
|
||||
}
|
||||
// Publish resolved event for auto-approved
|
||||
m.publishApprovalResolvedEvent(approval, true, comment)
|
||||
}
|
||||
|
||||
logLevel := slog.LevelInfo
|
||||
if status == store.ApprovalStatusLocalApproved {
|
||||
logLevel = slog.LevelDebug // Less noise for auto-approved
|
||||
}
|
||||
slog.Log(ctx, logLevel, "created approval with tool_use_id",
|
||||
"approval_id", approval.ID,
|
||||
"session_id", sessionID,
|
||||
"tool_name", toolName,
|
||||
"tool_use_id", toolUseID,
|
||||
"status", status,
|
||||
"auto_accepted", status == store.ApprovalStatusLocalApproved)
|
||||
|
||||
return approval, nil
|
||||
}
|
||||
|
||||
// isEditTool checks if a tool name is one of the edit tools
|
||||
func isEditTool(toolName string) bool {
|
||||
return toolName == "Edit" || toolName == "Write" || toolName == "MultiEdit"
|
||||
|
||||
@@ -248,7 +248,7 @@ func TestManager_CorrelateApproval(t *testing.T) {
|
||||
mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(pendingToolCall, nil)
|
||||
|
||||
// Mock correlating by tool ID
|
||||
mockStore.EXPECT().LinkConversationEventToApprovalUsingToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil)
|
||||
mockStore.EXPECT().CorrelateApprovalByToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil)
|
||||
|
||||
// Mock event publishing
|
||||
mockEventBus.EXPECT().Publish(gomock.Any())
|
||||
|
||||
@@ -12,9 +12,6 @@ type Manager interface {
|
||||
// Create a new approval
|
||||
CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error)
|
||||
|
||||
// Create approval with tool_use_id (Phase 4)
|
||||
CreateApprovalWithToolUseID(ctx context.Context, sessionID, toolName string, toolInput json.RawMessage, toolUseID string) (*store.Approval, error)
|
||||
|
||||
// Retrieval methods
|
||||
GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error)
|
||||
GetApproval(ctx context.Context, id string) (*store.Approval, error)
|
||||
|
||||
@@ -71,10 +71,8 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify all clients get subscriber count
|
||||
// Note: MCP server adds 1 subscriber for listening to approval events
|
||||
expectedSubscribers := numClients + 1
|
||||
if subCount := daemon.eventBus.GetSubscriberCount(); subCount != expectedSubscribers {
|
||||
t.Errorf("Expected %d subscribers (including MCP listener), got %d", expectedSubscribers, subCount)
|
||||
if subCount := daemon.eventBus.GetSubscriberCount(); subCount != numClients {
|
||||
t.Errorf("Expected %d subscribers, got %d", numClients, subCount)
|
||||
}
|
||||
|
||||
// Publish an event
|
||||
@@ -268,8 +266,6 @@ func TestDaemonMemoryStability(t *testing.T) {
|
||||
socketPath := testutil.CreateTestSocket(t)
|
||||
t.Setenv("HUMANLAYER_SOCKET_PATH", socketPath)
|
||||
t.Setenv("HUMANLAYER_LOG_LEVEL", "error")
|
||||
// Use in-memory database for tests
|
||||
t.Setenv("HUMANLAYER_DATABASE_PATH", ":memory:")
|
||||
|
||||
// Create and start daemon
|
||||
daemon, err := New()
|
||||
@@ -333,9 +329,9 @@ func TestDaemonMemoryStability(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Check final subscriber count (should be 1 for MCP listener)
|
||||
// Check final subscriber count (should be 0)
|
||||
finalCount := daemon.eventBus.GetSubscriberCount()
|
||||
if finalCount != 1 {
|
||||
t.Errorf("Expected 1 subscriber (MCP listener) after all clients disconnected, got %d", finalCount)
|
||||
if finalCount != 0 {
|
||||
t.Errorf("Expected 0 subscribers after all clients disconnected, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/config"
|
||||
"github.com/humanlayer/humanlayer/hld/mcp"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
@@ -27,8 +26,6 @@ type HTTPServer struct {
|
||||
sessionHandlers *handlers.SessionHandlers
|
||||
approvalHandlers *handlers.ApprovalHandlers
|
||||
sseHandler *handlers.SSEHandler
|
||||
approvalManager approval.Manager
|
||||
eventBus bus.EventBus
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
@@ -74,8 +71,6 @@ func NewHTTPServer(
|
||||
sessionHandlers: sessionHandlers,
|
||||
approvalHandlers: approvalHandlers,
|
||||
sseHandler: sseHandler,
|
||||
approvalManager: approvalManager,
|
||||
eventBus: eventBus,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,13 +91,6 @@ func (s *HTTPServer) Start(ctx context.Context) error {
|
||||
// Register SSE endpoint directly (not part of strict interface)
|
||||
v1.GET("/stream/events", s.sseHandler.StreamEvents)
|
||||
|
||||
// MCP endpoint (Phase 5: with event-driven approvals)
|
||||
mcpServer := mcp.NewMCPServer(s.approvalManager, s.eventBus)
|
||||
mcpServer.Start(ctx) // Start background processes with context
|
||||
v1.Any("/mcp", func(c *gin.Context) {
|
||||
mcpServer.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Create listener first to handle port 0
|
||||
addr := fmt.Sprintf("%s:%d", s.config.HTTPHost, s.config.HTTPPort)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MCPRequestLog captures all details of an MCP request
|
||||
type MCPRequestLog struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string][]string
|
||||
Body json.RawMessage
|
||||
Timestamp time.Time
|
||||
RequestID int
|
||||
}
|
||||
|
||||
// MCPTestServer wraps the real MCP server and logs all requests
|
||||
type MCPTestServer struct {
|
||||
realHandler http.Handler
|
||||
requests []MCPRequestLog
|
||||
requestsMutex sync.Mutex
|
||||
requestCount int
|
||||
}
|
||||
|
||||
func (s *MCPTestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture request details
|
||||
s.requestsMutex.Lock()
|
||||
s.requestCount++
|
||||
requestID := s.requestCount
|
||||
|
||||
// Read body for logging
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Reset body for real handler
|
||||
|
||||
// Log all headers
|
||||
headersCopy := make(map[string][]string)
|
||||
for k, v := range r.Header {
|
||||
headersCopy[k] = v
|
||||
}
|
||||
|
||||
log := MCPRequestLog{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Headers: headersCopy,
|
||||
Body: bodyBytes,
|
||||
Timestamp: time.Now(),
|
||||
RequestID: requestID,
|
||||
}
|
||||
s.requests = append(s.requests, log)
|
||||
s.requestsMutex.Unlock()
|
||||
|
||||
// Log request details to test output
|
||||
fmt.Printf("\n[MCP Request #%d] %s %s\n", requestID, r.Method, r.URL.Path)
|
||||
fmt.Printf("Headers:\n")
|
||||
for k, v := range r.Header {
|
||||
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
|
||||
}
|
||||
if len(bodyBytes) > 0 {
|
||||
fmt.Printf("Body: %s\n", string(bodyBytes))
|
||||
}
|
||||
fmt.Printf("---\n")
|
||||
|
||||
// Forward to real handler
|
||||
s.realHandler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *MCPTestServer) GetRequests() []MCPRequestLog {
|
||||
s.requestsMutex.Lock()
|
||||
defer s.requestsMutex.Unlock()
|
||||
return append([]MCPRequestLog{}, s.requests...)
|
||||
}
|
||||
|
||||
func TestMCPClaudeCodeSessionIDCorrelation(t *testing.T) {
|
||||
// Skip if Claude is not available
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("Claude CLI not available, skipping integration test")
|
||||
}
|
||||
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-claudecode")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-claudecode")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true") // Auto-deny for predictable responses
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create MCP test server wrapper
|
||||
mcpTestServer := &MCPTestServer{
|
||||
requests: []MCPRequestLog{},
|
||||
}
|
||||
|
||||
// Custom HTTP server setup to wrap MCP handler
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.New()
|
||||
|
||||
// Add health endpoint
|
||||
router.GET("/api/v1/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Wrap MCP endpoint with test server
|
||||
router.Any("/api/v1/mcp", func(c *gin.Context) {
|
||||
// First time setup - get real handler from daemon
|
||||
if mcpTestServer.realHandler == nil {
|
||||
// Get the real MCP handler from daemon
|
||||
// We'll create a simple MCP handler that auto-denies
|
||||
mcpTestServer.realHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simple MCP response for testing
|
||||
var req map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
|
||||
response := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "initialize":
|
||||
response["result"] = map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": map[string]interface{}{
|
||||
"name": "test-mcp-server",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"capabilities": map[string]interface{}{
|
||||
"tools": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
case "tools/list":
|
||||
response["result"] = map[string]interface{}{
|
||||
"tools": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"description": "Request permission to execute a tool",
|
||||
"inputSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"tool_name": map[string]string{"type": "string"},
|
||||
"input": map[string]string{"type": "object"},
|
||||
"tool_use_id": map[string]string{"type": "string"},
|
||||
},
|
||||
"required": []string{"tool_name", "input", "tool_use_id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case "tools/call":
|
||||
// Auto-deny
|
||||
response["result"] = map[string]interface{}{
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": `{"behavior": "deny", "message": "Auto-denied for testing"}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
response["error"] = map[string]interface{}{
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
})
|
||||
}
|
||||
|
||||
mcpTestServer.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Start HTTP server
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", httpPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &http.Server{
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
go func() {
|
||||
server.Serve(listener)
|
||||
}()
|
||||
defer server.Shutdown(context.Background())
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session in the database
|
||||
testSessionID := "test-claudecode-session"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-claudecode', 'claude-test', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 1, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, testSessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create claudecode client
|
||||
client, err := claudecode.NewClient()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare MCP configuration
|
||||
// The claudecode client will write this to a temp file and pass it to claude
|
||||
// We need to match the format claude expects for HTTP MCP servers
|
||||
mcpConfig := &claudecode.MCPConfig{
|
||||
MCPServers: map[string]claudecode.MCPServer{
|
||||
"humanlayer": {
|
||||
Command: "http", // This is just a placeholder
|
||||
Env: map[string]string{
|
||||
// The actual config will be written as JSON with type/url/headers
|
||||
"_config": fmt.Sprintf(`{"type":"http","url":"%s/api/v1/mcp","headers":{"X-Session-ID":"%s"}}`, baseURL, testSessionID),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create session config
|
||||
sessionConfig := claudecode.SessionConfig{
|
||||
Query: "Say 'test complete' and exit",
|
||||
Model: claudecode.ModelSonnet,
|
||||
OutputFormat: claudecode.OutputStreamJSON,
|
||||
MCPConfig: mcpConfig,
|
||||
PermissionPromptTool: "mcp__humanlayer__request_approval",
|
||||
MaxTurns: 1,
|
||||
WorkingDir: tempDir,
|
||||
Verbose: true,
|
||||
}
|
||||
|
||||
// Capture events from Claude
|
||||
var allEvents []claudecode.StreamEvent
|
||||
var eventsMutex sync.Mutex
|
||||
|
||||
// Launch Claude session
|
||||
t.Log("Launching Claude session with MCP config...")
|
||||
session, err := client.Launch(sessionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Capture events in background
|
||||
eventsDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(eventsDone)
|
||||
for event := range session.Events {
|
||||
eventsMutex.Lock()
|
||||
allEvents = append(allEvents, event)
|
||||
eventsMutex.Unlock()
|
||||
|
||||
// Log significant events
|
||||
switch event.Type {
|
||||
case "system":
|
||||
if event.Subtype == "init" {
|
||||
t.Logf("Claude session initialized: ID=%s, Model=%s", event.SessionID, event.Model)
|
||||
}
|
||||
case "mcp_servers":
|
||||
for _, server := range event.MCPServers {
|
||||
t.Logf("MCP Server %s: %s", server.Name, server.Status)
|
||||
}
|
||||
case "result":
|
||||
t.Logf("Session completed: ID=%s, Error=%v", event.SessionID, event.IsError)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for session to complete (with timeout)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
result, err := session.Wait()
|
||||
if err != nil {
|
||||
t.Logf("Session error: %v", err)
|
||||
} else if result != nil {
|
||||
t.Logf("Session result: %s", result.Result)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Session completed
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Log("Session timeout, interrupting...")
|
||||
session.Interrupt()
|
||||
<-done
|
||||
}
|
||||
|
||||
// Wait for events to be processed
|
||||
<-eventsDone
|
||||
|
||||
// Analyze captured MCP requests
|
||||
requests := mcpTestServer.GetRequests()
|
||||
t.Logf("\n=== MCP Request Analysis ===")
|
||||
t.Logf("Total MCP requests: %d", len(requests))
|
||||
|
||||
// Check for session ID in headers
|
||||
sessionIDFound := false
|
||||
var sessionIDHeaders []string
|
||||
|
||||
for i, req := range requests {
|
||||
t.Logf("\nRequest #%d: %s", i+1, req.Method)
|
||||
|
||||
// Check various possible session ID headers
|
||||
possibleHeaders := []string{
|
||||
"X-Session-ID",
|
||||
"X-Session-Id",
|
||||
"Session-ID",
|
||||
"Session-Id",
|
||||
"Mcp-Session-Id",
|
||||
"MCP-Session-ID",
|
||||
}
|
||||
|
||||
for _, header := range possibleHeaders {
|
||||
if values, ok := req.Headers[header]; ok && len(values) > 0 {
|
||||
sessionIDFound = true
|
||||
sessionIDHeaders = append(sessionIDHeaders, fmt.Sprintf("%s: %s", header, values[0]))
|
||||
t.Logf(" ✓ Found session ID header: %s = %s", header, values[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check if session ID is in the request body
|
||||
if len(req.Body) > 0 {
|
||||
var body map[string]interface{}
|
||||
if err := json.Unmarshal(req.Body, &body); err == nil {
|
||||
if sessionID, ok := body["session_id"].(string); ok && sessionID != "" {
|
||||
t.Logf(" ✓ Found session_id in body: %s", sessionID)
|
||||
}
|
||||
if params, ok := body["params"].(map[string]interface{}); ok {
|
||||
if sessionID, ok := params["session_id"].(string); ok && sessionID != "" {
|
||||
t.Logf(" ✓ Found session_id in params: %s", sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Analyze Claude events for session information
|
||||
t.Logf("\n=== Claude Event Analysis ===")
|
||||
t.Logf("Total events captured: %d", len(allEvents))
|
||||
|
||||
var claudeSessionID string
|
||||
for _, event := range allEvents {
|
||||
if event.SessionID != "" && claudeSessionID == "" {
|
||||
claudeSessionID = event.SessionID
|
||||
t.Logf("Claude session ID from events: %s", claudeSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// Final verdict
|
||||
t.Logf("\n=== VERDICT ===")
|
||||
if sessionIDFound {
|
||||
t.Logf("✓ Session ID IS sent in MCP request headers")
|
||||
t.Logf(" Headers found: %s", strings.Join(sessionIDHeaders, ", "))
|
||||
t.Logf(" Current implementation should work correctly")
|
||||
} else {
|
||||
t.Logf("✗ Session ID is NOT sent in MCP request headers")
|
||||
t.Logf(" Claude session ID: %s", claudeSessionID)
|
||||
t.Logf(" Need to implement alternative correlation mechanism")
|
||||
t.Logf(" Possible solutions:")
|
||||
t.Logf(" 1. Use MCP session initialization to establish mapping")
|
||||
t.Logf(" 2. Pass session ID in MCP server URL path")
|
||||
t.Logf(" 3. Use a unique MCP server per session")
|
||||
}
|
||||
|
||||
// Assert findings
|
||||
if !sessionIDFound {
|
||||
t.Error("Session ID is not being sent in MCP request headers - implementation needs revision")
|
||||
|
||||
// Provide detailed recommendations
|
||||
t.Log("\nRECOMMENDED CHANGES:")
|
||||
t.Log("1. Remove reliance on Session-ID header in MCP server")
|
||||
t.Log("2. Consider embedding session ID in MCP server URL:")
|
||||
t.Log(" - Change URL to: /api/v1/mcp/{session_id}")
|
||||
t.Log(" - Extract session ID from URL path in handler")
|
||||
t.Log("3. Or use MCP session correlation:")
|
||||
t.Log(" - Track MCP session ID from initialize method")
|
||||
t.Log(" - Map MCP session to HumanLayer session")
|
||||
}
|
||||
|
||||
// Additional diagnostics
|
||||
t.Logf("\n=== Additional Diagnostics ===")
|
||||
if len(requests) > 0 {
|
||||
t.Log("First request headers:")
|
||||
for k, v := range requests[0].Headers {
|
||||
t.Logf(" %s: %s", k, strings.Join(v, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,290 +0,0 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/daemon"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMCPStubEndpoint(t *testing.T) {
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp")
|
||||
_ = testutil.DatabasePath(t, "mcp") // Sets HUMANLAYER_DATABASE_PATH
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
t.Run("Initialize", func(t *testing.T) {
|
||||
// Test MCP initialize method
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]interface{}{},
|
||||
"clientInfo": map[string]interface{}{
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify response structure
|
||||
assert.Equal(t, "2.0", result["jsonrpc"])
|
||||
assert.Equal(t, float64(1), result["id"])
|
||||
|
||||
// Check result field
|
||||
res, ok := result["result"].(map[string]interface{})
|
||||
require.True(t, ok, "result field should be a map")
|
||||
|
||||
assert.Equal(t, "2025-03-26", res["protocolVersion"])
|
||||
|
||||
serverInfo, ok := res["serverInfo"].(map[string]interface{})
|
||||
require.True(t, ok, "serverInfo should be a map")
|
||||
assert.Equal(t, "humanlayer-daemon", serverInfo["name"])
|
||||
assert.Equal(t, "1.0.0", serverInfo["version"])
|
||||
})
|
||||
|
||||
t.Run("ToolsList", func(t *testing.T) {
|
||||
// Test tools/list method
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": map[string]interface{}{},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check tools list
|
||||
res, ok := result["result"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
tools, ok := res["tools"].([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, tools, 1)
|
||||
|
||||
tool := tools[0].(map[string]interface{})
|
||||
assert.Equal(t, "request_approval", tool["name"])
|
||||
assert.Contains(t, tool["description"], "Request permission to execute a tool")
|
||||
})
|
||||
|
||||
t.Run("UnknownMethod", func(t *testing.T) {
|
||||
// Test unknown method
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "unknown/method",
|
||||
"params": map[string]interface{}{},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have error response
|
||||
errResp, ok := result["error"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have error field")
|
||||
|
||||
assert.Equal(t, float64(-32601), errResp["code"])
|
||||
assert.Contains(t, errResp["message"], "not found")
|
||||
})
|
||||
|
||||
t.Run("AutoDeny", func(t *testing.T) {
|
||||
// Set auto-deny mode
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true")
|
||||
defer os.Unsetenv("MCP_AUTO_DENY_ALL")
|
||||
|
||||
// Restart daemon with auto-deny
|
||||
cancel()
|
||||
|
||||
// Wait for shutdown
|
||||
select {
|
||||
case <-errCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Daemon did not shut down")
|
||||
}
|
||||
|
||||
// Create new daemon with auto-deny
|
||||
d2, err := daemon.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
errCh2 := make(chan error, 1)
|
||||
go func() {
|
||||
errCh2 <- d2.Run(ctx2)
|
||||
}()
|
||||
|
||||
// Wait for server to be ready again
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
|
||||
// Test tools/call with auto-deny
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": "test_123",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check auto-deny response
|
||||
res, ok := result["result"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
content, ok := res["content"].([]interface{})
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
|
||||
contentItem := content[0].(map[string]interface{})
|
||||
assert.Equal(t, "text", contentItem["type"])
|
||||
|
||||
// Parse the JSON text content
|
||||
text := contentItem["text"].(string)
|
||||
var approval map[string]interface{}
|
||||
err = json.Unmarshal([]byte(text), &approval)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "deny", approval["behavior"])
|
||||
assert.Contains(t, approval["message"], "Auto-denied")
|
||||
})
|
||||
}
|
||||
|
||||
// getFreePort gets a free TCP port for testing
|
||||
func getFreePort(t *testing.T) int {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
@@ -1,404 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/daemon"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMCPPhase4ApprovalCreation(t *testing.T) {
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-phase4")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-phase4")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session
|
||||
sessionID := "test-session-phase4"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-phase4', 'claude-phase4', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("ApprovalCreatedWithToolUseID", func(t *testing.T) {
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_use_phase4_123"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
// Send request in background (it will block waiting for approval)
|
||||
go func() {
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
client.Do(httpReq)
|
||||
}()
|
||||
|
||||
// Wait for approval to be created
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check database for approval with tool_use_id
|
||||
var count int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Expected exactly one approval with tool_use_id")
|
||||
|
||||
// Verify approval details
|
||||
var approvalID, toolName, status string
|
||||
var toolInput string
|
||||
err = db.QueryRow(`
|
||||
SELECT id, tool_name, tool_input, status
|
||||
FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&approvalID, &toolName, &toolInput, &status)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test_tool", toolName)
|
||||
assert.Equal(t, "pending", status)
|
||||
assert.Contains(t, toolInput, `"test":"data"`)
|
||||
assert.NotEmpty(t, approvalID)
|
||||
|
||||
t.Logf("Successfully created approval with ID=%s, tool_use_id=%s", approvalID, toolUseID)
|
||||
})
|
||||
|
||||
t.Run("AutoApprovalWithDangerousSkip", func(t *testing.T) {
|
||||
// Enable dangerous skip permissions
|
||||
_, err := db.Exec(`
|
||||
UPDATE sessions
|
||||
SET dangerously_skip_permissions = 1
|
||||
WHERE id = ?
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_use_auto_approve"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "edit_tool",
|
||||
"input": map[string]interface{}{"file": "test.txt"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get immediate response due to auto-approval
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
||||
// Check that the approval was created and auto-approved
|
||||
var status, comment string
|
||||
err = db.QueryRow(`
|
||||
SELECT status, comment
|
||||
FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&status, &comment)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "approved", status)
|
||||
assert.Contains(t, comment, "dangerous skip permissions")
|
||||
|
||||
// Verify response contains allow behavior
|
||||
if responseContent, ok := result["result"].(map[string]interface{}); ok {
|
||||
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
|
||||
if textContent, ok := content[0].(map[string]interface{}); ok {
|
||||
if text, ok := textContent["text"].(string); ok {
|
||||
var responseData map[string]interface{}
|
||||
json.Unmarshal([]byte(text), &responseData)
|
||||
assert.Equal(t, "allow", responseData["behavior"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disable dangerous skip for cleanup
|
||||
_, err = db.Exec(`
|
||||
UPDATE sessions
|
||||
SET dangerously_skip_permissions = 0
|
||||
WHERE id = ?
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("MultipleApprovalsDifferentToolUseIDs", func(t *testing.T) {
|
||||
// Create multiple approval requests with different tool_use_ids
|
||||
toolUseIDs := []string{"multi_1", "multi_2", "multi_3"}
|
||||
|
||||
for _, toolUseID := range toolUseIDs {
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": toolUseID,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "multi_tool",
|
||||
"input": map[string]interface{}{"id": toolUseID},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
// Send requests in background
|
||||
go func() {
|
||||
client := &http.Client{Timeout: 1 * time.Second}
|
||||
client.Do(httpReq)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for approvals to be created
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Verify all approvals were created with correct tool_use_ids
|
||||
for _, toolUseID := range toolUseIDs {
|
||||
var count int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ? AND session_id = ?
|
||||
`, toolUseID, sessionID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count, "Expected approval for tool_use_id=%s", toolUseID)
|
||||
}
|
||||
|
||||
// Verify total count
|
||||
var totalCount int
|
||||
err := db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id IN ('multi_1', 'multi_2', 'multi_3')
|
||||
AND session_id = ?
|
||||
`, sessionID).Scan(&totalCount)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, totalCount, "Expected exactly 3 approvals")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPPhase4AutoDenyMode(t *testing.T) {
|
||||
// Set auto-deny mode
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true")
|
||||
defer os.Unsetenv("MCP_AUTO_DENY_ALL")
|
||||
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-phase4-autodeny")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-phase4-autodeny")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Create a test session
|
||||
sessionID := "test-session-autodeny"
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, query, model, working_dir,
|
||||
status, created_at, last_activity_at, auto_accept_edits,
|
||||
dangerously_skip_permissions, max_turns, system_prompt,
|
||||
custom_instructions, cost_usd, input_tokens, output_tokens,
|
||||
duration_ms, num_turns, result_content, error_message
|
||||
) VALUES (
|
||||
?, 'run-autodeny', 'claude-autodeny', 'test query', 'claude-3-sonnet', '/tmp',
|
||||
'running', datetime('now'), datetime('now'), 0, 0, 10, '',
|
||||
'', 0.0, 0, 0, 0, 0, '', ''
|
||||
)
|
||||
`, sessionID)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("AutoDenyDoesNotCreateApproval", func(t *testing.T) {
|
||||
// Send MCP approval request
|
||||
toolUseID := "test_autodeny"
|
||||
req := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": toolUseID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(req)
|
||||
httpReq, _ := http.NewRequest("POST", baseURL+"/api/v1/mcp", bytes.NewBuffer(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Session-ID", sessionID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get immediate deny response
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
|
||||
// Verify deny response
|
||||
if responseContent, ok := result["result"].(map[string]interface{}); ok {
|
||||
if content, ok := responseContent["content"].([]interface{}); ok && len(content) > 0 {
|
||||
if textContent, ok := content[0].(map[string]interface{}); ok {
|
||||
if text, ok := textContent["text"].(string); ok {
|
||||
var responseData map[string]interface{}
|
||||
json.Unmarshal([]byte(text), &responseData)
|
||||
assert.Equal(t, "deny", responseData["behavior"])
|
||||
assert.Contains(t, responseData["message"], "Auto-denied")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no approval was created in database
|
||||
var count int
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(*) FROM approvals
|
||||
WHERE tool_use_id = ?
|
||||
`, toolUseID).Scan(&count)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count, "No approval should be created in auto-deny mode")
|
||||
})
|
||||
}
|
||||
@@ -1,260 +0,0 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/daemon"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMCPServerFullImplementation(t *testing.T) {
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-full")
|
||||
_ = testutil.DatabasePath(t, "mcp-full")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
os.Setenv("MCP_AUTO_DENY_ALL", "true") // Enable auto-deny for predictable testing
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for HTTP server to be ready
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", httpPort)
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/v1/health", baseURL))
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 100*time.Millisecond, "HTTP server did not start")
|
||||
|
||||
t.Run("ToolsListSchemaValidation", func(t *testing.T) {
|
||||
// Test that tools/list returns proper schema structure
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "tools/list",
|
||||
"params": map[string]interface{}{},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate the tool schema structure
|
||||
res := result["result"].(map[string]interface{})
|
||||
tools := res["tools"].([]interface{})
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
tool := tools[0].(map[string]interface{})
|
||||
assert.Equal(t, "request_approval", tool["name"])
|
||||
assert.Equal(t, "Request permission to execute a tool", tool["description"])
|
||||
|
||||
// Check input schema structure
|
||||
inputSchema := tool["inputSchema"].(map[string]interface{})
|
||||
assert.Equal(t, "object", inputSchema["type"])
|
||||
|
||||
properties := inputSchema["properties"].(map[string]interface{})
|
||||
assert.Contains(t, properties, "tool_name")
|
||||
assert.Contains(t, properties, "input")
|
||||
assert.Contains(t, properties, "tool_use_id")
|
||||
|
||||
// Verify required fields
|
||||
required := inputSchema["required"].([]interface{})
|
||||
assert.Len(t, required, 3)
|
||||
assert.Contains(t, required, "tool_name")
|
||||
assert.Contains(t, required, "input")
|
||||
assert.Contains(t, required, "tool_use_id")
|
||||
|
||||
// Check annotations (mark3labs specific)
|
||||
if annotations, ok := tool["annotations"].(map[string]interface{}); ok {
|
||||
assert.NotNil(t, annotations["destructiveHint"])
|
||||
assert.NotNil(t, annotations["openWorldHint"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AutoDenyResponseStructure", func(t *testing.T) {
|
||||
// Test that auto-deny returns proper JSON structure
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_tool",
|
||||
"input": map[string]interface{}{"command": "ls -la"},
|
||||
"tool_use_id": "test_use_123",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate response structure
|
||||
res := result["result"].(map[string]interface{})
|
||||
content := res["content"].([]interface{})
|
||||
require.Len(t, content, 1)
|
||||
|
||||
contentItem := content[0].(map[string]interface{})
|
||||
assert.Equal(t, "text", contentItem["type"])
|
||||
|
||||
// Parse and validate the JSON in the text field
|
||||
text := contentItem["text"].(string)
|
||||
var approvalResponse map[string]interface{}
|
||||
err = json.Unmarshal([]byte(text), &approvalResponse)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "deny", approvalResponse["behavior"])
|
||||
assert.Equal(t, "Auto-denied for testing", approvalResponse["message"])
|
||||
})
|
||||
|
||||
t.Run("SessionIDHeaderExtraction", func(t *testing.T) {
|
||||
// Test that X-Session-ID header is properly handled
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
"tool_name": "test_with_session",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
"tool_use_id": "session_test_456",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest("POST",
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
bytes.NewBuffer(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Session-ID", "test-session-789")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should still get auto-deny response (session ID doesn't affect auto-deny)
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we got a valid response (session ID was accepted)
|
||||
assert.Contains(t, result, "result")
|
||||
assert.NotContains(t, result, "error")
|
||||
})
|
||||
|
||||
t.Run("MissingRequiredFields", func(t *testing.T) {
|
||||
// Test that missing required fields return appropriate errors
|
||||
reqBody := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "request_approval",
|
||||
"arguments": map[string]interface{}{
|
||||
// Missing tool_use_id
|
||||
"tool_name": "incomplete_tool",
|
||||
"input": map[string]interface{}{"test": "data"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/api/v1/mcp", baseURL),
|
||||
"application/json",
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still work in auto-deny mode (gets empty string for missing field)
|
||||
// but in real mode would be problematic
|
||||
if errField, hasError := result["error"]; hasError {
|
||||
// If there's an error, it should be about the missing field
|
||||
errMap := errField.(map[string]interface{})
|
||||
assert.Contains(t, errMap["message"], "required")
|
||||
} else {
|
||||
// In auto-deny mode, it might still process with empty tool_use_id
|
||||
res := result["result"].(map[string]interface{})
|
||||
assert.NotNil(t, res)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,282 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// HeaderCaptureServer captures all HTTP request headers for analysis
|
||||
type HeaderCaptureServer struct {
|
||||
mu sync.Mutex
|
||||
requests []RequestCapture
|
||||
}
|
||||
|
||||
type RequestCapture struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string][]string
|
||||
Body []byte
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
func (s *HeaderCaptureServer) Handler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture request
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
capture := RequestCapture{
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
Headers: r.Header.Clone(),
|
||||
Body: body,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.requests = append(s.requests, capture)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Log headers
|
||||
fmt.Printf("\n[%s %s]\n", r.Method, r.URL.Path)
|
||||
fmt.Println("Headers:")
|
||||
for k, v := range r.Header {
|
||||
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
|
||||
}
|
||||
|
||||
// Parse JSON-RPC request
|
||||
var req map[string]interface{}
|
||||
json.Unmarshal(body, &req)
|
||||
|
||||
method, _ := req["method"].(string)
|
||||
id := req["id"]
|
||||
|
||||
// Create response
|
||||
response := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
}
|
||||
|
||||
// Handle MCP methods
|
||||
switch method {
|
||||
case "initialize":
|
||||
response["result"] = map[string]interface{}{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"serverInfo": map[string]interface{}{
|
||||
"name": "test-server",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"capabilities": map[string]interface{}{
|
||||
"tools": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
case "tools/list":
|
||||
response["result"] = map[string]interface{}{
|
||||
"tools": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"inputSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
case "tools/call":
|
||||
response["result"] = map[string]interface{}{
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "Test response",
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
response["error"] = map[string]interface{}{
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeaderCaptureServer) GetRequests() []RequestCapture {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]RequestCapture{}, s.requests...)
|
||||
}
|
||||
|
||||
func TestMCPHeaderTransmission(t *testing.T) {
|
||||
// Skip if Claude CLI is not available
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("Claude CLI not available")
|
||||
}
|
||||
|
||||
// Create header capture server
|
||||
captureServer := &HeaderCaptureServer{}
|
||||
|
||||
// Start HTTP server
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.New()
|
||||
router.Any("/mcp", gin.WrapF(captureServer.Handler()))
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
|
||||
server := &http.Server{Handler: router}
|
||||
go server.Serve(listener)
|
||||
defer server.Shutdown(context.Background())
|
||||
|
||||
// Create MCP config file with custom headers
|
||||
tempDir := t.TempDir()
|
||||
mcpConfigPath := fmt.Sprintf("%s/mcp-config.json", tempDir)
|
||||
|
||||
testSessionID := "test-session-123"
|
||||
mcpConfig := map[string]interface{}{
|
||||
"mcpServers": map[string]interface{}{
|
||||
"test": map[string]interface{}{
|
||||
"type": "http",
|
||||
"url": fmt.Sprintf("%s/mcp", baseURL),
|
||||
"headers": map[string]string{
|
||||
"X-Session-ID": testSessionID,
|
||||
"X-Custom-Header": "custom-value",
|
||||
"Authorization": "Bearer test-token",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
configBytes, _ := json.MarshalIndent(mcpConfig, "", " ")
|
||||
err = os.WriteFile(mcpConfigPath, configBytes, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("MCP Config:\n%s", string(configBytes))
|
||||
|
||||
// Launch Claude with MCP config
|
||||
cmd := exec.Command("claude",
|
||||
"--print", "test",
|
||||
"--mcp-config", mcpConfigPath,
|
||||
"--max-turns", "1",
|
||||
"--model", "sonnet",
|
||||
"--output-format", "json",
|
||||
)
|
||||
cmd.Dir = tempDir
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
t.Logf("Claude output:\n%s", string(output))
|
||||
|
||||
// Allow some time for any async operations
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Analyze captured requests
|
||||
requests := captureServer.GetRequests()
|
||||
t.Logf("\n=== Captured %d MCP Requests ===", len(requests))
|
||||
|
||||
sessionIDFound := false
|
||||
customHeaderFound := false
|
||||
authHeaderFound := false
|
||||
|
||||
for i, req := range requests {
|
||||
t.Logf("\nRequest #%d: %s %s", i+1, req.Method, req.Path)
|
||||
|
||||
// Check for our custom headers
|
||||
if sessionIDs, ok := req.Headers["X-Session-Id"]; ok && len(sessionIDs) > 0 {
|
||||
sessionIDFound = true
|
||||
sessionID := sessionIDs[0]
|
||||
t.Logf(" ✓ X-Session-ID: %s", sessionID)
|
||||
if sessionID != testSessionID {
|
||||
t.Errorf(" ✗ Session ID mismatch: got %s, want %s", sessionID, testSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
if customs, ok := req.Headers["X-Custom-Header"]; ok && len(customs) > 0 {
|
||||
customHeaderFound = true
|
||||
t.Logf(" ✓ X-Custom-Header: %s", customs[0])
|
||||
}
|
||||
|
||||
if auths, ok := req.Headers["Authorization"]; ok && len(auths) > 0 {
|
||||
authHeaderFound = true
|
||||
t.Logf(" ✓ Authorization: %s", auths[0])
|
||||
}
|
||||
|
||||
// Log all headers for debugging
|
||||
t.Log(" All headers:")
|
||||
for k, v := range req.Headers {
|
||||
t.Logf(" %s: %s", k, strings.Join(v, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// Verdict
|
||||
t.Log("\n=== VERDICT ===")
|
||||
if sessionIDFound && customHeaderFound && authHeaderFound {
|
||||
t.Log("✓ All custom headers are transmitted correctly")
|
||||
t.Log("✓ Session ID correlation via headers WILL work")
|
||||
} else {
|
||||
t.Log("✗ Custom headers are NOT being transmitted")
|
||||
if !sessionIDFound {
|
||||
t.Error("X-Session-ID header not found in MCP requests")
|
||||
}
|
||||
if !customHeaderFound {
|
||||
t.Error("X-Custom-Header not found in MCP requests")
|
||||
}
|
||||
if !authHeaderFound {
|
||||
t.Error("Authorization header not found in MCP requests")
|
||||
}
|
||||
|
||||
t.Log("\n=== RECOMMENDATIONS ===")
|
||||
t.Log("1. Embed session ID in the MCP server URL path")
|
||||
t.Log("2. Use unique MCP server instances per session")
|
||||
t.Log("3. Implement session correlation via MCP protocol messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPSessionCorrelationAlternatives(t *testing.T) {
|
||||
t.Log("\n=== Alternative Session Correlation Methods ===")
|
||||
|
||||
t.Log("\n1. URL Path Embedding:")
|
||||
t.Log(" - Change MCP endpoint to: /api/v1/mcp/:session_id")
|
||||
t.Log(" - Extract session ID from URL path in handler")
|
||||
t.Log(" - Pro: Simple, reliable, no header dependency")
|
||||
t.Log(" - Con: Requires URL generation per session")
|
||||
|
||||
t.Log("\n2. MCP Protocol Session:")
|
||||
t.Log(" - Use MCP's initialize response to establish session")
|
||||
t.Log(" - Store mapping: mcp_session_id -> humanlayer_session_id")
|
||||
t.Log(" - Pro: Protocol-native solution")
|
||||
t.Log(" - Con: Requires stateful MCP server")
|
||||
|
||||
t.Log("\n3. Token-based Correlation:")
|
||||
t.Log(" - Generate unique token per session")
|
||||
t.Log(" - Pass token in MCP server name or URL")
|
||||
t.Log(" - Pro: Secure, unique per session")
|
||||
t.Log(" - Con: Token management complexity")
|
||||
|
||||
t.Log("\n4. Process-based Correlation:")
|
||||
t.Log(" - Track Claude process ID")
|
||||
t.Log(" - Map process to session at launch")
|
||||
t.Log(" - Pro: OS-level tracking")
|
||||
t.Log(" - Con: Complex, platform-specific")
|
||||
}
|
||||
@@ -1,358 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/daemon"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMCPToolUseIDCorrelation verifies that when an approval is triggered
|
||||
// by a running Claude Code instance, the tool_use_id is properly set in the database
|
||||
func TestMCPToolUseIDCorrelation(t *testing.T) {
|
||||
// Setup isolated environment
|
||||
socketPath := testutil.SocketPath(t, "mcp-tool-use-id")
|
||||
dbPath := testutil.DatabasePath(t, "mcp-tool-use-id")
|
||||
|
||||
// Get a free port for HTTP server
|
||||
httpPort := getFreePort(t)
|
||||
|
||||
// Override environment
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_PORT", fmt.Sprintf("%d", httpPort))
|
||||
os.Setenv("HUMANLAYER_DAEMON_HTTP_HOST", "127.0.0.1")
|
||||
os.Setenv("HUMANLAYER_API_KEY", "") // Disable cloud API
|
||||
|
||||
// Create isolated config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
require.NoError(t, os.MkdirAll(configDir, 0755))
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
require.NoError(t, os.WriteFile(configFile, []byte(`{}`), 0644))
|
||||
|
||||
// Create daemon
|
||||
d, err := daemon.New()
|
||||
require.NoError(t, err, "Failed to create daemon")
|
||||
|
||||
// Start daemon in background
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Wait for daemon to be ready
|
||||
require.Eventually(t, func() bool {
|
||||
// Check if the HTTP health endpoint is responding
|
||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/health", httpPort))
|
||||
if err == nil && resp != nil {
|
||||
resp.Body.Close()
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
return false
|
||||
}, 10*time.Second, 100*time.Millisecond, "Daemon did not start")
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// We'll use daemon's REST API to launch sessions properly
|
||||
|
||||
t.Run("SingleApprovalWithToolUseID", func(t *testing.T) {
|
||||
// Clear any existing approvals
|
||||
_, err = db.Exec("DELETE FROM approvals")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create temp directory for session
|
||||
testWorkDir := t.TempDir()
|
||||
|
||||
// Prepare session creation request for REST API
|
||||
createReq := map[string]interface{}{
|
||||
"query": "Write 'Hello World' to a file called test.txt and then exit",
|
||||
"model": "sonnet",
|
||||
"permission_prompt_tool": "mcp__codelayer__request_approval",
|
||||
"max_turns": 3,
|
||||
"working_dir": testWorkDir,
|
||||
"mcp_config": map[string]interface{}{
|
||||
"mcp_servers": map[string]interface{}{
|
||||
"codelayer": map[string]interface{}{
|
||||
"type": "http",
|
||||
"url": fmt.Sprintf("http://127.0.0.1:%d/api/v1/mcp", httpPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Send REST API request to create session
|
||||
reqBody, _ := json.Marshal(createReq)
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d/api/v1/sessions", httpPort), bytes.NewBuffer(reqBody))
|
||||
require.NoError(t, err)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check response status
|
||||
require.Equal(t, http.StatusCreated, resp.StatusCode, "Expected 201 Created")
|
||||
|
||||
// Parse response
|
||||
var createResp map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&createResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get session ID from response
|
||||
data := createResp["data"].(map[string]interface{})
|
||||
sessionID := data["session_id"].(string)
|
||||
runID := data["run_id"].(string)
|
||||
t.Logf("Launched session: %s with run_id: %s", sessionID, runID)
|
||||
|
||||
// Let Claude run for a bit to trigger approvals
|
||||
t.Log("Waiting for Claude to trigger approvals...")
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Now check the database for approvals
|
||||
rows, err := db.Query(`
|
||||
SELECT id, session_id, tool_name, tool_use_id, status, comment
|
||||
FROM approvals
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var approvals []struct {
|
||||
ID string
|
||||
SessionID string
|
||||
ToolName string
|
||||
ToolUseID sql.NullString
|
||||
Status string
|
||||
Comment sql.NullString
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var a struct {
|
||||
ID string
|
||||
SessionID string
|
||||
ToolName string
|
||||
ToolUseID sql.NullString
|
||||
Status string
|
||||
Comment sql.NullString
|
||||
}
|
||||
err := rows.Scan(&a.ID, &a.SessionID, &a.ToolName, &a.ToolUseID, &a.Status, &a.Comment)
|
||||
require.NoError(t, err)
|
||||
approvals = append(approvals, a)
|
||||
}
|
||||
|
||||
// Log what we found
|
||||
t.Logf("Found %d approvals in database:", len(approvals))
|
||||
for i, a := range approvals {
|
||||
t.Logf(" Approval %d:", i+1)
|
||||
t.Logf(" ID: %s", a.ID)
|
||||
t.Logf(" Session ID: %s", a.SessionID)
|
||||
t.Logf(" Tool Name: %s", a.ToolName)
|
||||
t.Logf(" Tool Use ID: %v (Valid: %v)", a.ToolUseID.String, a.ToolUseID.Valid)
|
||||
t.Logf(" Status: %s", a.Status)
|
||||
if a.Comment.Valid {
|
||||
t.Logf(" Comment: %s", a.Comment.String)
|
||||
}
|
||||
}
|
||||
|
||||
// Also check conversation events for tool uses
|
||||
var toolUseCount int
|
||||
rows2, err := db.Query(`
|
||||
SELECT tool_id, tool_name
|
||||
FROM conversation_events
|
||||
WHERE session_id = ? AND tool_id IS NOT NULL
|
||||
ORDER BY created_at DESC
|
||||
`, sessionID)
|
||||
if err == nil {
|
||||
defer rows2.Close()
|
||||
for rows2.Next() {
|
||||
var toolID, toolName string
|
||||
if err := rows2.Scan(&toolID, &toolName); err == nil {
|
||||
toolUseCount++
|
||||
t.Logf(" Tool use in events: %s (ID: %s)", toolName, toolID)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Logf("Found %d tool uses in conversation_events", toolUseCount)
|
||||
|
||||
// Verify that we have at least one approval
|
||||
if len(approvals) > 0 {
|
||||
// Check that tool_use_id is set
|
||||
for _, a := range approvals {
|
||||
if !a.ToolUseID.Valid || a.ToolUseID.String == "" {
|
||||
t.Errorf("Approval %s has no tool_use_id set!", a.ID)
|
||||
} else {
|
||||
t.Logf("✓ Approval %s has tool_use_id: %s", a.ID, a.ToolUseID.String)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Log("No approvals were created - this might indicate the test didn't trigger any tools")
|
||||
t.Log("This can happen if Claude doesn't attempt to write the file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ParallelApprovalsWithDistinctToolUseIDs", func(t *testing.T) {
|
||||
// Clear any existing approvals
|
||||
_, err = db.Exec("DELETE FROM approvals")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create temp directory for session
|
||||
testWorkDir := t.TempDir()
|
||||
|
||||
// Prepare session creation request for REST API
|
||||
createReq := map[string]interface{}{
|
||||
"query": "Create 3 files in parallel: file1.txt with 'One', file2.txt with 'Two', file3.txt with 'Three'. Use parallel tool calls if possible.",
|
||||
"model": "sonnet",
|
||||
"permission_prompt_tool": "mcp__codelayer__request_approval",
|
||||
"max_turns": 3,
|
||||
"working_dir": testWorkDir,
|
||||
"mcp_config": map[string]interface{}{
|
||||
"mcp_servers": map[string]interface{}{
|
||||
"codelayer": map[string]interface{}{
|
||||
"type": "http",
|
||||
"url": fmt.Sprintf("http://127.0.0.1:%d/api/v1/mcp", httpPort),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Send REST API request to create session
|
||||
reqBody, _ := json.Marshal(createReq)
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d/api/v1/sessions", httpPort), bytes.NewBuffer(reqBody))
|
||||
require.NoError(t, err)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check response status
|
||||
require.Equal(t, http.StatusCreated, resp.StatusCode, "Expected 201 Created")
|
||||
|
||||
// Parse response
|
||||
var createResp map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&createResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get session ID from response
|
||||
data := createResp["data"].(map[string]interface{})
|
||||
sessionID := data["session_id"].(string)
|
||||
t.Logf("Launched parallel session: %s", sessionID)
|
||||
|
||||
// Let Claude run for a bit
|
||||
t.Log("Waiting for parallel operations...")
|
||||
time.Sleep(7 * time.Second)
|
||||
|
||||
// Check database for approvals
|
||||
rows, err := db.Query(`
|
||||
SELECT id, tool_use_id, tool_name
|
||||
FROM approvals
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var approvals []struct {
|
||||
ID string
|
||||
ToolUseID sql.NullString
|
||||
ToolName string
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var a struct {
|
||||
ID string
|
||||
ToolUseID sql.NullString
|
||||
ToolName string
|
||||
}
|
||||
err := rows.Scan(&a.ID, &a.ToolUseID, &a.ToolName)
|
||||
require.NoError(t, err)
|
||||
approvals = append(approvals, a)
|
||||
}
|
||||
|
||||
t.Logf("Found %d approvals for parallel operations", len(approvals))
|
||||
|
||||
// Verify each approval has a unique tool_use_id
|
||||
toolUseIDMap := make(map[string]bool)
|
||||
for _, a := range approvals {
|
||||
if !a.ToolUseID.Valid || a.ToolUseID.String == "" {
|
||||
t.Errorf("Approval %s has no tool_use_id!", a.ID)
|
||||
} else {
|
||||
if toolUseIDMap[a.ToolUseID.String] {
|
||||
t.Errorf("Duplicate tool_use_id found: %s", a.ToolUseID.String)
|
||||
}
|
||||
toolUseIDMap[a.ToolUseID.String] = true
|
||||
t.Logf("✓ Approval %s has unique tool_use_id: %s", a.ID, a.ToolUseID.String)
|
||||
}
|
||||
}
|
||||
|
||||
// Cross-reference with conversation events
|
||||
var toolUseEvents []struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
rows2, err := db.Query(`
|
||||
SELECT tool_id, tool_name
|
||||
FROM conversation_events
|
||||
WHERE session_id = ? AND tool_id IS NOT NULL
|
||||
`, sessionID)
|
||||
if err == nil {
|
||||
defer rows2.Close()
|
||||
for rows2.Next() {
|
||||
var toolID, toolName string
|
||||
if err := rows2.Scan(&toolID, &toolName); err == nil {
|
||||
toolUseEvents = append(toolUseEvents, struct {
|
||||
ID string
|
||||
Name string
|
||||
}{ID: toolID, Name: toolName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Cross-referencing %d tool_use events with approvals", len(toolUseEvents))
|
||||
for _, toolUse := range toolUseEvents {
|
||||
found := false
|
||||
for _, a := range approvals {
|
||||
if a.ToolUseID.Valid && a.ToolUseID.String == toolUse.ID {
|
||||
found = true
|
||||
t.Logf("✓ Tool use %s matched with approval %s", toolUse.ID, a.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found && toolUse.ID != "" {
|
||||
t.Logf("⚠ Tool use %s (%s) has no matching approval", toolUse.ID, toolUse.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Cleanup: shutdown daemon
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && err != context.Canceled {
|
||||
t.Errorf("Daemon exited with error: %v", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Daemon did not shut down in time")
|
||||
}
|
||||
}
|
||||
@@ -11,13 +11,10 @@ require (
|
||||
github.com/getkin/kin-openapi v0.132.0
|
||||
github.com/gin-contrib/cors v1.7.6
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/humanlayer/humanlayer/claudecode-go v0.0.0-00010101000000-000000000000
|
||||
github.com/mark3labs/mcp-go v0.37.0
|
||||
github.com/mattn/go-sqlite3 v1.14.28
|
||||
github.com/oapi-codegen/runtime v1.1.2
|
||||
github.com/r3labs/sse/v2 v2.10.0
|
||||
github.com/spf13/viper v1.20.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.uber.org/mock v0.5.2
|
||||
@@ -25,8 +22,6 @@ require (
|
||||
|
||||
require (
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic v1.13.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.4 // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
@@ -41,7 +36,6 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.26.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
@@ -56,6 +50,7 @@ require (
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/perimeterx/marshmallow v1.1.5 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/r3labs/sse/v2 v2.10.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
@@ -64,8 +59,6 @@ require (
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.0 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
|
||||
34
hld/go.sum
34
hld/go.sum
@@ -1,11 +1,7 @@
|
||||
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bytedance/sonic v1.13.3 h1:MS8gmaH16Gtirygw7jV91pDCN33NyMrPbN7qiYhEsF0=
|
||||
github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
@@ -49,15 +45,11 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -75,8 +67,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mark3labs/mcp-go v0.37.0 h1:BywvZLPRT6Zx6mMG/MJfxLSZQkTGIcJSEGKsvr4DsoQ=
|
||||
github.com/mark3labs/mcp-go v0.37.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
@@ -133,11 +123,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
|
||||
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
||||
@@ -147,37 +132,18 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV
|
||||
golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc=
|
||||
golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
|
||||
|
||||
@@ -1,260 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// contextKey is the type for context keys
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
// sessionIDKey is the context key for session ID
|
||||
sessionIDKey contextKey = "session_id"
|
||||
)
|
||||
|
||||
// ApprovalDecision represents the outcome of an approval request
|
||||
type ApprovalDecision struct {
|
||||
Approved bool
|
||||
Comment string
|
||||
}
|
||||
|
||||
// MCPServer wraps the mark3labs MCP server
|
||||
type MCPServer struct {
|
||||
mcpServer *server.MCPServer
|
||||
httpServer *server.StreamableHTTPServer
|
||||
approvalManager approval.Manager
|
||||
eventBus bus.EventBus
|
||||
autoDenyAll bool
|
||||
pendingApprovals sync.Map // map[string]chan ApprovalDecision
|
||||
}
|
||||
|
||||
// NewMCPServer creates the full MCP server implementation
|
||||
func NewMCPServer(approvalManager approval.Manager, eventBus bus.EventBus) *MCPServer {
|
||||
autoDeny := os.Getenv("MCP_AUTO_DENY_ALL") == "true"
|
||||
|
||||
s := &MCPServer{
|
||||
approvalManager: approvalManager,
|
||||
eventBus: eventBus,
|
||||
autoDenyAll: autoDeny,
|
||||
}
|
||||
|
||||
// Create MCP server
|
||||
s.mcpServer = server.NewMCPServer(
|
||||
"humanlayer-daemon",
|
||||
"1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
// Add request_approval tool
|
||||
s.mcpServer.AddTool(
|
||||
mcp.NewTool("request_approval",
|
||||
mcp.WithDescription("Request permission to execute a tool"),
|
||||
mcp.WithString("tool_name",
|
||||
mcp.Description("The name of the tool requesting permission"),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithObject("input",
|
||||
mcp.Description("The input to the tool"),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithString("tool_use_id",
|
||||
mcp.Description("Unique identifier for this tool use"),
|
||||
mcp.Required(),
|
||||
),
|
||||
),
|
||||
s.handleRequestApproval,
|
||||
)
|
||||
|
||||
// Create HTTP server (stateless for now)
|
||||
s.httpServer = server.NewStreamableHTTPServer(
|
||||
s.mcpServer,
|
||||
server.WithStateLess(true),
|
||||
)
|
||||
|
||||
// Don't start goroutine here - wait for Start() to be called
|
||||
return s
|
||||
}
|
||||
|
||||
// Start initializes the MCP server's background processes
|
||||
func (s *MCPServer) Start(ctx context.Context) {
|
||||
if s.eventBus != nil {
|
||||
go s.listenForApprovalDecisions(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MCPServer) handleRequestApproval(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
toolName := request.GetString("tool_name", "")
|
||||
input := request.GetArguments()["input"]
|
||||
toolUseID := request.GetString("tool_use_id", "")
|
||||
|
||||
slog.Info("MCP approval requested",
|
||||
"tool_name", toolName,
|
||||
"tool_use_id", toolUseID,
|
||||
"auto_deny", s.autoDenyAll)
|
||||
|
||||
// Auto-deny takes precedence
|
||||
if s.autoDenyAll {
|
||||
slog.Info("Auto-denying approval", "tool_use_id", toolUseID)
|
||||
|
||||
responseData := map[string]interface{}{
|
||||
"behavior": "deny",
|
||||
"message": "Auto-denied for testing",
|
||||
}
|
||||
responseJSON, _ := json.Marshal(responseData)
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(responseJSON),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get session_id from context
|
||||
sessionID, _ := ctx.Value(sessionIDKey).(string)
|
||||
if sessionID == "" {
|
||||
return nil, fmt.Errorf("missing session_id in context")
|
||||
}
|
||||
|
||||
// Marshal input to JSON
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal input: %w", err)
|
||||
}
|
||||
|
||||
// Create approval with tool_use_id
|
||||
approval, err := s.approvalManager.CreateApprovalWithToolUseID(ctx, sessionID, toolName, inputJSON, toolUseID)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create approval", "error", err)
|
||||
return nil, fmt.Errorf("failed to create approval: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("Created approval", "approval_id", approval.ID, "status", approval.Status)
|
||||
|
||||
// Check if the approval was auto-approved
|
||||
if approval.Status == "approved" {
|
||||
// Return allow behavior for auto-approved
|
||||
responseData := map[string]interface{}{
|
||||
"behavior": "allow",
|
||||
"updatedInput": input,
|
||||
}
|
||||
responseJSON, _ := json.Marshal(responseData)
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(responseJSON),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register for event-driven approval resolution
|
||||
decisionChan := make(chan ApprovalDecision, 1)
|
||||
s.pendingApprovals.Store(toolUseID, decisionChan)
|
||||
defer s.pendingApprovals.Delete(toolUseID)
|
||||
|
||||
// Wait for approval decision
|
||||
select {
|
||||
case decision := <-decisionChan:
|
||||
responseData := map[string]interface{}{
|
||||
"behavior": "deny",
|
||||
"message": decision.Comment,
|
||||
}
|
||||
if decision.Approved {
|
||||
responseData = map[string]interface{}{
|
||||
"behavior": "allow",
|
||||
"updatedInput": input,
|
||||
}
|
||||
}
|
||||
responseJSON, _ := json.Marshal(responseData)
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(responseJSON),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
// For the moment, we don't timeout approvals, but in the future
|
||||
// may choose to add a timeout or determine otherwise for resumed sessions
|
||||
// case <-time.After(5 * time.Minute):
|
||||
// return nil, fmt.Errorf("approval timeout")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MCPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract session_id from header and add to context
|
||||
sessionID := r.Header.Get("X-Session-ID")
|
||||
if sessionID == "" {
|
||||
// Try to extract from MCP session if available
|
||||
mcpSessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if mcpSessionID != "" {
|
||||
sessionID = mcpSessionID
|
||||
}
|
||||
}
|
||||
|
||||
// Add session_id to context for future use
|
||||
ctx := context.WithValue(r.Context(), sessionIDKey, sessionID)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
s.httpServer.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// listenForApprovalDecisions listens for approval resolution events and notifies waiting handlers
|
||||
func (s *MCPServer) listenForApprovalDecisions(ctx context.Context) {
|
||||
sub := s.eventBus.Subscribe(ctx, bus.EventFilter{
|
||||
Types: []bus.EventType{bus.EventApprovalResolved},
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("MCP approval listener shutting down")
|
||||
return
|
||||
case event, ok := <-sub.Channel:
|
||||
if !ok {
|
||||
slog.Info("MCP approval listener channel closed")
|
||||
return
|
||||
}
|
||||
toolUseID, _ := event.Data["tool_use_id"].(string)
|
||||
approved, _ := event.Data["approved"].(bool)
|
||||
comment, _ := event.Data["response_text"].(string)
|
||||
|
||||
if toolUseID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find pending approval channel
|
||||
if ch, ok := s.pendingApprovals.Load(toolUseID); ok {
|
||||
select {
|
||||
case ch.(chan ApprovalDecision) <- ApprovalDecision{
|
||||
Approved: approved,
|
||||
Comment: comment,
|
||||
}:
|
||||
slog.Info("Sent approval decision", "tool_use_id", toolUseID, "approved", approved)
|
||||
default:
|
||||
slog.Warn("Channel full or closed", "tool_use_id", toolUseID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,47 +20,30 @@ import { mapValues } from '../runtime';
|
||||
*/
|
||||
export interface MCPServer {
|
||||
/**
|
||||
* Server type (http for HTTP servers, omit for stdio)
|
||||
* Command to execute
|
||||
* @type {string}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
type?: string;
|
||||
command: string;
|
||||
/**
|
||||
* Command to execute (for stdio servers)
|
||||
* @type {string}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
command?: string;
|
||||
/**
|
||||
* Command arguments (for stdio servers)
|
||||
* Command arguments
|
||||
* @type {Array<string>}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
args?: Array<string>;
|
||||
/**
|
||||
* Environment variables (for stdio servers)
|
||||
* Environment variables
|
||||
* @type {{ [key: string]: string; }}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
env?: { [key: string]: string; };
|
||||
/**
|
||||
* HTTP endpoint URL (for HTTP servers)
|
||||
* @type {string}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
url?: string;
|
||||
/**
|
||||
* HTTP headers to include (for HTTP servers)
|
||||
* @type {{ [key: string]: string; }}
|
||||
* @memberof MCPServer
|
||||
*/
|
||||
headers?: { [key: string]: string; };
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a given object implements the MCPServer interface.
|
||||
*/
|
||||
export function instanceOfMCPServer(value: object): value is MCPServer {
|
||||
if (!('command' in value) || value['command'] === undefined) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -74,12 +57,9 @@ export function MCPServerFromJSONTyped(json: any, ignoreDiscriminator: boolean):
|
||||
}
|
||||
return {
|
||||
|
||||
'type': json['type'] == null ? undefined : json['type'],
|
||||
'command': json['command'] == null ? undefined : json['command'],
|
||||
'command': json['command'],
|
||||
'args': json['args'] == null ? undefined : json['args'],
|
||||
'env': json['env'] == null ? undefined : json['env'],
|
||||
'url': json['url'] == null ? undefined : json['url'],
|
||||
'headers': json['headers'] == null ? undefined : json['headers'],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -94,11 +74,8 @@ export function MCPServerToJSONTyped(value?: MCPServer | null, ignoreDiscriminat
|
||||
|
||||
return {
|
||||
|
||||
'type': value['type'],
|
||||
'command': value['command'],
|
||||
'args': value['args'],
|
||||
'env': value['env'],
|
||||
'url': value['url'],
|
||||
'headers': value['headers'],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -779,123 +779,4 @@ func TestContinueSessionInheritance(t *testing.T) {
|
||||
t.Errorf("Child didn't inherit grandparent title: got %q, want %q", childSession.Title, grandparentTitle)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTPMCPServerUpdatesXSessionIDHeader", func(t *testing.T) {
|
||||
// This test verifies that when continuing a session with HTTP MCP servers,
|
||||
// the X-Session-ID header is updated to the child session ID, not inherited
|
||||
// from the parent session ID.
|
||||
|
||||
// Create parent session
|
||||
parentSessionID := "parent-http-mcp"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-http-mcp",
|
||||
ClaudeSessionID: "claude-http-mcp",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "http mcp query",
|
||||
Model: "claude-3-opus-20240229",
|
||||
WorkingDir: "/tmp/test",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Store HTTP MCP server with X-Session-ID header for parent
|
||||
// For HTTP servers: Command="http", ArgsJSON=["URL"], EnvJSON=headers
|
||||
parentMCPServers := []store.MCPServer{
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
Name: "http-test-server",
|
||||
Command: "http", // Indicates HTTP type
|
||||
ArgsJSON: `["http://localhost:8080/mcp"]`, // URL as single-element array
|
||||
EnvJSON: `{"X-Session-ID": "parent-http-mcp", "Authorization": "Bearer token123"}`, // Headers
|
||||
},
|
||||
}
|
||||
if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, parentMCPServers); err != nil {
|
||||
t.Fatalf("Failed to store MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Continue session
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: parentSessionID,
|
||||
Query: "continue http mcp",
|
||||
}
|
||||
|
||||
_, _ = manager.ContinueSession(ctx, req)
|
||||
// Expected to fail due to missing Claude binary
|
||||
|
||||
// Find the child session
|
||||
sessions, err := sqliteStore.ListSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
var childSession *store.Session
|
||||
for _, s := range sessions {
|
||||
if s.ParentSessionID == parentSessionID {
|
||||
childSession = s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if childSession == nil {
|
||||
t.Fatal("Child session not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Get MCP servers for child session
|
||||
childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get child MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Should have inherited the MCP server
|
||||
if len(childMCPServers) != 1 {
|
||||
t.Fatalf("Expected 1 MCP server, got %d", len(childMCPServers))
|
||||
}
|
||||
|
||||
childMCPServer := childMCPServers[0]
|
||||
|
||||
// Verify basic inheritance
|
||||
if childMCPServer.Name != "http-test-server" {
|
||||
t.Errorf("MCP server name not inherited: got %s, want http-test-server", childMCPServer.Name)
|
||||
}
|
||||
if childMCPServer.Command != "http" {
|
||||
t.Errorf("MCP server type not inherited: got %s, want http", childMCPServer.Command)
|
||||
}
|
||||
|
||||
// Verify URL was inherited (stored in ArgsJSON)
|
||||
var childArgs []string
|
||||
if err := json.Unmarshal([]byte(childMCPServer.ArgsJSON), &childArgs); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child args: %v", err)
|
||||
}
|
||||
if len(childArgs) != 1 || childArgs[0] != "http://localhost:8080/mcp" {
|
||||
t.Errorf("MCP server URL not inherited: got %v, want [http://localhost:8080/mcp]", childArgs)
|
||||
}
|
||||
|
||||
// Parse headers (stored in EnvJSON) and verify X-Session-ID was updated
|
||||
var childHeaders map[string]string
|
||||
if err := json.Unmarshal([]byte(childMCPServer.EnvJSON), &childHeaders); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child headers: %v", err)
|
||||
}
|
||||
|
||||
// CRITICAL: X-Session-ID should be the CHILD session ID, not the parent's
|
||||
if xSessionID, ok := childHeaders["X-Session-ID"]; !ok {
|
||||
t.Error("X-Session-ID header missing in child MCP server")
|
||||
} else if xSessionID != childSession.ID {
|
||||
t.Errorf("X-Session-ID not updated to child session ID: got %s, want %s", xSessionID, childSession.ID)
|
||||
t.Log("This is the bug! X-Session-ID should be replaced with the child session ID")
|
||||
}
|
||||
|
||||
// Other headers should be preserved
|
||||
if auth, ok := childHeaders["Authorization"]; !ok {
|
||||
t.Error("Authorization header not inherited")
|
||||
} else if auth != "Bearer token123" {
|
||||
t.Errorf("Authorization header value changed: got %s, want Bearer token123", auth)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -69,42 +69,24 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig)
|
||||
claudeConfig := config.SessionConfig
|
||||
|
||||
// Add HUMANLAYER_RUN_ID and HUMANLAYER_DAEMON_SOCKET to MCP server environment
|
||||
// For HTTP servers, inject session ID header
|
||||
if claudeConfig.MCPConfig != nil {
|
||||
slog.Debug("configuring MCP servers", "count", len(claudeConfig.MCPConfig.MCPServers))
|
||||
for name, server := range claudeConfig.MCPConfig.MCPServers {
|
||||
// Check if this is an HTTP MCP server
|
||||
if server.Type == "http" {
|
||||
// For HTTP servers, inject session ID header if not already set
|
||||
if server.Headers == nil {
|
||||
server.Headers = make(map[string]string)
|
||||
}
|
||||
// Only inject if not already set (allow override)
|
||||
if _, exists := server.Headers["X-Session-ID"]; !exists {
|
||||
server.Headers["X-Session-ID"] = sessionID
|
||||
}
|
||||
slog.Debug("configured HTTP MCP server",
|
||||
"name", name,
|
||||
"url", server.URL,
|
||||
"session_id", sessionID)
|
||||
} else {
|
||||
// For stdio servers, add environment variables
|
||||
if server.Env == nil {
|
||||
server.Env = make(map[string]string)
|
||||
}
|
||||
server.Env["HUMANLAYER_RUN_ID"] = runID
|
||||
// Add daemon socket path so MCP servers connect to the correct daemon
|
||||
if m.socketPath != "" {
|
||||
server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath
|
||||
}
|
||||
slog.Debug("configured stdio MCP server",
|
||||
"name", name,
|
||||
"command", server.Command,
|
||||
"args", server.Args,
|
||||
"run_id", runID,
|
||||
"socket_path", m.socketPath)
|
||||
if server.Env == nil {
|
||||
server.Env = make(map[string]string)
|
||||
}
|
||||
server.Env["HUMANLAYER_RUN_ID"] = runID
|
||||
// Add daemon socket path so MCP servers connect to the correct daemon
|
||||
if m.socketPath != "" {
|
||||
server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath
|
||||
}
|
||||
claudeConfig.MCPConfig.MCPServers[name] = server
|
||||
slog.Debug("configured MCP server",
|
||||
"name", name,
|
||||
"command", server.Command,
|
||||
"args", server.Args,
|
||||
"run_id", runID,
|
||||
"socket_path", m.socketPath)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("no MCP config provided")
|
||||
@@ -160,11 +142,7 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig)
|
||||
if claudeConfig.MCPConfig != nil {
|
||||
mcpServerCount = len(claudeConfig.MCPConfig.MCPServers)
|
||||
for name, server := range claudeConfig.MCPConfig.MCPServers {
|
||||
if server.Type == "http" {
|
||||
mcpServersDetail += fmt.Sprintf("[%s: type=http url=%s headers=%v] ", name, server.URL, server.Headers)
|
||||
} else {
|
||||
mcpServersDetail += fmt.Sprintf("[%s: cmd=%s args=%v env=%v] ", name, server.Command, server.Args, server.Env)
|
||||
}
|
||||
mcpServersDetail += fmt.Sprintf("[%s: cmd=%s args=%v env=%v] ", name, server.Command, server.Args, server.Env)
|
||||
}
|
||||
}
|
||||
slog.Info("launching Claude session with configuration",
|
||||
@@ -236,15 +214,8 @@ func (m *Manager) LaunchSession(ctx context.Context, config LaunchSessionConfig)
|
||||
// Reconcile any existing approvals for this run_id
|
||||
if m.approvalReconciler != nil {
|
||||
go func() {
|
||||
// Give the session a moment to start (with cancellation support)
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
// Continue with reconciliation
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, exit early
|
||||
return
|
||||
}
|
||||
|
||||
// Give the session a moment to start
|
||||
time.Sleep(2 * time.Second)
|
||||
if err := m.approvalReconciler.ReconcileApprovalsForSession(ctx, runID); err != nil {
|
||||
slog.Error("failed to reconcile approvals for session",
|
||||
"session_id", sessionID,
|
||||
@@ -1212,24 +1183,10 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig
|
||||
env = map[string]string{}
|
||||
}
|
||||
|
||||
// Check if this is an HTTP server (stored with command="http")
|
||||
if server.Command == "http" {
|
||||
// HTTP server - extract URL from args and headers from env
|
||||
var urls []string
|
||||
if err := json.Unmarshal([]byte(server.ArgsJSON), &urls); err == nil && len(urls) > 0 {
|
||||
config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{
|
||||
Type: "http",
|
||||
URL: urls[0],
|
||||
Headers: env, // Headers were stored in EnvJSON
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Traditional stdio server
|
||||
config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{
|
||||
Command: server.Command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
}
|
||||
config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{
|
||||
Command: server.Command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
}
|
||||
}
|
||||
slog.Debug("inherited MCP servers from parent session",
|
||||
@@ -1298,28 +1255,15 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig
|
||||
}
|
||||
|
||||
// Add run_id and daemon socket to MCP server environments
|
||||
// For HTTP servers, inject session ID header
|
||||
|
||||
if config.MCPConfig != nil {
|
||||
for name, server := range config.MCPConfig.MCPServers {
|
||||
// Check if this is an HTTP MCP server
|
||||
if server.Type == "http" {
|
||||
// For HTTP servers, always set session ID header to child session ID
|
||||
if server.Headers == nil {
|
||||
server.Headers = make(map[string]string)
|
||||
}
|
||||
// Always set X-Session-ID to the new child session ID (replaces inherited parent ID)
|
||||
server.Headers["X-Session-ID"] = sessionID
|
||||
} else {
|
||||
// For stdio servers, add environment variables
|
||||
if server.Env == nil {
|
||||
server.Env = make(map[string]string)
|
||||
}
|
||||
server.Env["HUMANLAYER_RUN_ID"] = runID
|
||||
// Add daemon socket path so MCP servers connect to the correct daemon
|
||||
if m.socketPath != "" {
|
||||
server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath
|
||||
}
|
||||
if server.Env == nil {
|
||||
server.Env = make(map[string]string)
|
||||
}
|
||||
server.Env["HUMANLAYER_RUN_ID"] = runID
|
||||
// Add daemon socket path so MCP servers connect to the correct daemon
|
||||
if m.socketPath != "" {
|
||||
server.Env["HUMANLAYER_DAEMON_SOCKET"] = m.socketPath
|
||||
}
|
||||
config.MCPConfig.MCPServers[name] = server
|
||||
}
|
||||
@@ -1395,15 +1339,8 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig
|
||||
// Reconcile any existing approvals for this run_id (same run_id is reused for continuations)
|
||||
if m.approvalReconciler != nil {
|
||||
go func() {
|
||||
// Give the session a moment to start (with cancellation support)
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
// Continue with reconciliation
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, exit early
|
||||
return
|
||||
}
|
||||
|
||||
// Give the session a moment to start
|
||||
time.Sleep(2 * time.Second)
|
||||
if err := m.approvalReconciler.ReconcileApprovalsForSession(ctx, runID); err != nil {
|
||||
slog.Error("failed to reconcile approvals for continued session",
|
||||
"session_id", sessionID,
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigration14_ToolUseID(t *testing.T) {
|
||||
// Create an in-memory database for testing
|
||||
s, err := store.NewSQLiteStore(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = s.Close() }()
|
||||
|
||||
// First create a session to satisfy foreign key constraint
|
||||
session := &store.Session{
|
||||
ID: "test-session-1",
|
||||
RunID: "test-run-1",
|
||||
Query: "test query",
|
||||
Status: store.SessionStatusRunning,
|
||||
}
|
||||
err = s.CreateSession(context.Background(), session)
|
||||
require.NoError(t, err, "Should be able to create session")
|
||||
|
||||
// Create a test approval with tool_use_id
|
||||
toolUseID := "test-tool-use-id-123"
|
||||
approval := &store.Approval{
|
||||
ID: "test-approval-1",
|
||||
RunID: "test-run-1",
|
||||
SessionID: "test-session-1",
|
||||
ToolUseID: &toolUseID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "test-tool",
|
||||
ToolInput: []byte(`{"test": "data"}`),
|
||||
}
|
||||
|
||||
// Create the approval
|
||||
err = s.CreateApproval(context.Background(), approval)
|
||||
require.NoError(t, err, "Should be able to create approval with tool_use_id")
|
||||
|
||||
// Retrieve the approval
|
||||
retrieved, err := s.GetApproval(context.Background(), "test-approval-1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrieved)
|
||||
|
||||
// Verify tool_use_id was saved and retrieved correctly
|
||||
assert.NotNil(t, retrieved.ToolUseID, "ToolUseID should not be nil")
|
||||
if retrieved.ToolUseID != nil {
|
||||
assert.Equal(t, toolUseID, *retrieved.ToolUseID, "ToolUseID should match")
|
||||
}
|
||||
|
||||
// Create another session for the second approval
|
||||
session2 := &store.Session{
|
||||
ID: "test-session-2",
|
||||
RunID: "test-run-2",
|
||||
Query: "test query 2",
|
||||
Status: store.SessionStatusRunning,
|
||||
}
|
||||
err = s.CreateSession(context.Background(), session2)
|
||||
require.NoError(t, err, "Should be able to create second session")
|
||||
|
||||
// Test creating approval without tool_use_id (nullable field)
|
||||
approval2 := &store.Approval{
|
||||
ID: "test-approval-2",
|
||||
RunID: "test-run-2",
|
||||
SessionID: "test-session-2",
|
||||
ToolUseID: nil, // Explicitly nil
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "test-tool-2",
|
||||
ToolInput: []byte(`{"test": "data2"}`),
|
||||
}
|
||||
|
||||
err = s.CreateApproval(context.Background(), approval2)
|
||||
require.NoError(t, err, "Should be able to create approval without tool_use_id")
|
||||
|
||||
// Retrieve and verify it's nil
|
||||
retrieved2, err := s.GetApproval(context.Background(), "test-approval-2")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, retrieved2.ToolUseID, "ToolUseID should be nil when not provided")
|
||||
}
|
||||
@@ -695,75 +695,6 @@ func (s *SQLiteStore) applyMigrations() error {
|
||||
slog.Info("Migration 13 applied successfully")
|
||||
}
|
||||
|
||||
// Migration 14: Add tool_use_id column to approvals table
|
||||
if currentVersion < 14 {
|
||||
slog.Info("Applying migration 14: Add tool_use_id column to approvals table")
|
||||
|
||||
// Check if column already exists for idempotency
|
||||
var columnExists int
|
||||
err = s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM pragma_table_info('approvals')
|
||||
WHERE name = 'tool_use_id'
|
||||
`).Scan(&columnExists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check tool_use_id column: %w", err)
|
||||
}
|
||||
|
||||
// Only add column if it doesn't exist
|
||||
if columnExists == 0 {
|
||||
_, err = s.db.Exec(`
|
||||
ALTER TABLE approvals
|
||||
ADD COLUMN tool_use_id TEXT
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add tool_use_id column: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create index for efficient lookups
|
||||
_, err = s.db.Exec(`
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_tool_use_id
|
||||
ON approvals(tool_use_id)
|
||||
WHERE tool_use_id IS NOT NULL
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tool_use_id index: %w", err)
|
||||
}
|
||||
|
||||
// Update existing approvals to populate tool_use_id from correlated events
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE approvals
|
||||
SET tool_use_id = (
|
||||
SELECT ce.tool_id
|
||||
FROM conversation_events ce
|
||||
WHERE ce.approval_id = approvals.id
|
||||
AND ce.tool_id IS NOT NULL
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM conversation_events ce
|
||||
WHERE ce.approval_id = approvals.id
|
||||
AND ce.tool_id IS NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to populate tool_use_id for existing approvals (non-critical)", "error", err)
|
||||
// This is non-critical as existing approvals may not have correlation
|
||||
}
|
||||
|
||||
// Record migration
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO schema_version (version, description)
|
||||
VALUES (14, 'Add tool_use_id column to approvals table for direct correlation')
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration 14: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("Migration 14 applied successfully")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1825,8 +1756,8 @@ func (s *SQLiteStore) CorrelateApproval(ctx context.Context, sessionID string, t
|
||||
return nil
|
||||
}
|
||||
|
||||
// LinkConversationEventToApprovalUsingToolID correlates an approval with a specific tool call by tool_id
|
||||
func (s *SQLiteStore) LinkConversationEventToApprovalUsingToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error {
|
||||
// CorrelateApprovalByToolID correlates an approval with a specific tool call by tool_id
|
||||
func (s *SQLiteStore) CorrelateApprovalByToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error {
|
||||
// Update the tool call directly by tool_id
|
||||
updateQuery := `
|
||||
UPDATE conversation_events
|
||||
@@ -1969,14 +1900,14 @@ func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) er
|
||||
|
||||
query := `
|
||||
INSERT INTO approvals (
|
||||
id, run_id, session_id, tool_use_id, status, created_at,
|
||||
tool_name, tool_input, comment
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
id, run_id, session_id, status, created_at,
|
||||
tool_name, tool_input
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query,
|
||||
approval.ID, approval.RunID, approval.SessionID, approval.ToolUseID, approval.Status.String(), approval.CreatedAt,
|
||||
approval.ToolName, string(approval.ToolInput), approval.Comment,
|
||||
approval.ID, approval.RunID, approval.SessionID, approval.Status.String(), approval.CreatedAt,
|
||||
approval.ToolName, string(approval.ToolInput),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create approval: %w", err)
|
||||
@@ -1987,20 +1918,19 @@ func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) er
|
||||
// GetApproval retrieves an approval by ID
|
||||
func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, error) {
|
||||
query := `
|
||||
SELECT id, run_id, session_id, tool_use_id, status, created_at, responded_at,
|
||||
SELECT id, run_id, session_id, status, created_at, responded_at,
|
||||
tool_name, tool_input, comment
|
||||
FROM approvals WHERE id = ?
|
||||
`
|
||||
|
||||
var approval Approval
|
||||
var toolUseID sql.NullString
|
||||
var respondedAt sql.NullTime
|
||||
var comment sql.NullString
|
||||
var statusStr string
|
||||
var toolInputStr string
|
||||
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &toolUseID, &statusStr,
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
|
||||
&approval.CreatedAt, &respondedAt,
|
||||
&approval.ToolName, &toolInputStr, &comment,
|
||||
)
|
||||
@@ -2018,9 +1948,6 @@ func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, er
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if toolUseID.Valid {
|
||||
approval.ToolUseID = &toolUseID.String
|
||||
}
|
||||
if respondedAt.Valid {
|
||||
approval.RespondedAt = &respondedAt.Time
|
||||
}
|
||||
@@ -2033,7 +1960,7 @@ func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, er
|
||||
// GetPendingApprovals retrieves all pending approvals for a session
|
||||
func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error) {
|
||||
query := `
|
||||
SELECT id, run_id, session_id, tool_use_id, status, created_at, responded_at,
|
||||
SELECT id, run_id, session_id, status, created_at, responded_at,
|
||||
tool_name, tool_input, comment
|
||||
FROM approvals
|
||||
WHERE session_id = ? AND status = ?
|
||||
@@ -2049,14 +1976,13 @@ func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string)
|
||||
var approvals []*Approval
|
||||
for rows.Next() {
|
||||
var approval Approval
|
||||
var toolUseID sql.NullString
|
||||
var respondedAt sql.NullTime
|
||||
var comment sql.NullString
|
||||
var statusStr string
|
||||
var toolInputStr string
|
||||
|
||||
err := rows.Scan(
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &toolUseID, &statusStr,
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
|
||||
&approval.CreatedAt, &respondedAt,
|
||||
&approval.ToolName, &toolInputStr, &comment,
|
||||
)
|
||||
@@ -2071,9 +1997,6 @@ func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if toolUseID.Valid {
|
||||
approval.ToolUseID = &toolUseID.String
|
||||
}
|
||||
if respondedAt.Valid {
|
||||
approval.RespondedAt = &respondedAt.Time
|
||||
}
|
||||
@@ -2141,47 +2064,22 @@ func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServ
|
||||
servers := make([]MCPServer, 0, len(config))
|
||||
for _, name := range names {
|
||||
server := config[name]
|
||||
argsJSON, err := json.Marshal(server.Args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal args: %w", err)
|
||||
}
|
||||
|
||||
// For HTTP servers, store the configuration differently
|
||||
// We'll use Command field to store the type, ArgsJSON for URL, and EnvJSON for headers
|
||||
var command string
|
||||
var argsJSON string
|
||||
var envJSON string
|
||||
|
||||
if server.Type == "http" {
|
||||
// HTTP server
|
||||
command = "http" // Use "http" as the command to indicate HTTP type
|
||||
argsJSON = fmt.Sprintf(`["%s"]`, server.URL) // Store URL as single-element array
|
||||
|
||||
// Store headers in EnvJSON
|
||||
headersData, err := json.Marshal(server.Headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal headers: %w", err)
|
||||
}
|
||||
envJSON = string(headersData)
|
||||
} else {
|
||||
// Traditional stdio server
|
||||
command = server.Command
|
||||
|
||||
argsData, err := json.Marshal(server.Args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal args: %w", err)
|
||||
}
|
||||
argsJSON = string(argsData)
|
||||
|
||||
envData, err := json.Marshal(server.Env)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal env: %w", err)
|
||||
}
|
||||
envJSON = string(envData)
|
||||
envJSON, err := json.Marshal(server.Env)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal env: %w", err)
|
||||
}
|
||||
|
||||
servers = append(servers, MCPServer{
|
||||
SessionID: sessionID,
|
||||
Name: name,
|
||||
Command: command,
|
||||
ArgsJSON: argsJSON,
|
||||
EnvJSON: envJSON,
|
||||
Command: server.Command,
|
||||
ArgsJSON: string(argsJSON),
|
||||
EnvJSON: string(envJSON),
|
||||
})
|
||||
}
|
||||
return servers, nil
|
||||
|
||||
@@ -31,7 +31,7 @@ type ConversationStore interface {
|
||||
GetToolCallByID(ctx context.Context, toolID string) (*ConversationEvent, error)
|
||||
MarkToolCallCompleted(ctx context.Context, toolID string, sessionID string) error
|
||||
CorrelateApproval(ctx context.Context, sessionID string, toolName string, approvalID string) error
|
||||
LinkConversationEventToApprovalUsingToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error
|
||||
CorrelateApprovalByToolID(ctx context.Context, sessionID string, toolID string, approvalID string) error
|
||||
UpdateApprovalStatus(ctx context.Context, approvalID string, status string) error
|
||||
|
||||
// MCP server operations
|
||||
@@ -201,7 +201,6 @@ type Approval struct {
|
||||
ID string `json:"id"`
|
||||
RunID string `json:"run_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
ToolUseID *string `json:"tool_use_id,omitempty"`
|
||||
Status ApprovalStatus `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RespondedAt *time.Time `json:"responded_at,omitempty"`
|
||||
|
||||
@@ -46,16 +46,13 @@ export const launchCommand = async (query: string, options: LaunchOptions = {})
|
||||
|
||||
try {
|
||||
// Build MCP config (approvals enabled by default unless explicitly disabled)
|
||||
// Phase 6: Using HTTP MCP endpoint instead of stdio
|
||||
const daemonPort = process.env.HUMANLAYER_DAEMON_HTTP_PORT || '7777'
|
||||
const mcpConfig =
|
||||
options.approvals !== false
|
||||
? {
|
||||
mcpServers: {
|
||||
codelayer: {
|
||||
type: 'http',
|
||||
url: `http://localhost:${daemonPort}/api/v1/mcp`,
|
||||
// Session ID will be added as header by Claude Code
|
||||
approvals: {
|
||||
command: 'npx',
|
||||
args: ['humanlayer', 'mcp', 'claude_approvals'],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -69,7 +66,7 @@ export const launchCommand = async (query: string, options: LaunchOptions = {})
|
||||
working_dir: options.workingDir || process.cwd(),
|
||||
max_turns: options.maxTurns,
|
||||
mcp_config: mcpConfig,
|
||||
permission_prompt_tool: mcpConfig ? 'mcp__codelayer__request_approval' : undefined,
|
||||
permission_prompt_tool: mcpConfig ? 'mcp__approvals__request_permission' : undefined,
|
||||
dangerously_skip_permissions: options.dangerouslySkipPermissions,
|
||||
dangerously_skip_permissions_timeout: options.dangerouslySkipPermissionsTimeout
|
||||
? parseInt(options.dangerouslySkipPermissionsTimeout) * 60 * 1000
|
||||
|
||||
@@ -10,6 +10,7 @@ import { launchCommand } from './commands/launch.js'
|
||||
import { alertCommand } from './commands/alert.js'
|
||||
import { thoughtsCommand } from './commands/thoughts.js'
|
||||
import { joinWaitlistCommand } from './commands/joinWaitlist.js'
|
||||
import { startDefaultMCPServer, startClaudeApprovalsMCPServer } from './mcp.js'
|
||||
import {
|
||||
getDefaultConfigPath,
|
||||
resolveFullConfig,
|
||||
@@ -66,7 +67,7 @@ async function authenticate(printSelectedProject: boolean = false) {
|
||||
|
||||
program.name('humanlayer').description('HumanLayer, but on your command-line.').version(VERSION)
|
||||
|
||||
const UNPROTECTED_COMMANDS = ['config', 'login', 'thoughts', 'join-waitlist', 'launch']
|
||||
const UNPROTECTED_COMMANDS = ['config', 'login', 'thoughts', 'join-waitlist', 'launch', 'mcp']
|
||||
|
||||
program.hook('preAction', async (thisCmd, actionCmd) => {
|
||||
// Get the full command path by traversing up the command hierarchy
|
||||
@@ -171,6 +172,36 @@ program
|
||||
.option('--daemon-socket <path>', 'Path to daemon socket')
|
||||
.action(alertCommand)
|
||||
|
||||
const mcpCommand = program.command('mcp').description('MCP server functionality')
|
||||
|
||||
mcpCommand
|
||||
.command('serve')
|
||||
.description('Start the default MCP server for contact_human functionality')
|
||||
.action(startDefaultMCPServer)
|
||||
|
||||
mcpCommand
|
||||
.command('claude_approvals')
|
||||
.description('Start the Claude approvals MCP server for permission requests')
|
||||
.action(startClaudeApprovalsMCPServer)
|
||||
|
||||
mcpCommand
|
||||
.command('wrapper')
|
||||
.description('Wrap an existing MCP server with human approval functionality (not implemented yet)')
|
||||
.action(() => {
|
||||
console.log('MCP wrapper functionality is not implemented yet.')
|
||||
console.log('This will allow wrapping any existing MCP server with human approval.')
|
||||
process.exit(1)
|
||||
})
|
||||
|
||||
mcpCommand
|
||||
.command('inspector')
|
||||
.description('Run MCP inspector for debugging MCP servers')
|
||||
.argument('[command]', 'MCP server command to inspect', 'serve')
|
||||
.action(command => {
|
||||
const args = ['@modelcontextprotocol/inspector', 'node', 'dist/index.js', 'mcp', command]
|
||||
spawn('npx', args, { stdio: 'inherit', cwd: process.cwd() })
|
||||
})
|
||||
|
||||
// Add thoughts command
|
||||
thoughtsCommand(program)
|
||||
|
||||
|
||||
274
hlyr/src/mcp.ts
Normal file
274
hlyr/src/mcp.ts
Normal file
@@ -0,0 +1,274 @@
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
ErrorCode,
|
||||
ListToolsRequestSchema,
|
||||
McpError,
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { humanlayer } from '@humanlayer/sdk'
|
||||
import { resolveFullConfig } from './config.js'
|
||||
import { DaemonClient } from './daemonClient.js'
|
||||
import { logger } from './mcpLogger.js'
|
||||
|
||||
function validateAuth(): void {
|
||||
const config = resolveFullConfig({})
|
||||
|
||||
if (!config.api_key) {
|
||||
console.error('Error: No HumanLayer API token found.')
|
||||
console.error('Please set HUMANLAYER_API_KEY environment variable or run `humanlayer login`')
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the default MCP server that provides contact_human functionality
|
||||
* Uses web UI by default when no contact channel is configured
|
||||
*/
|
||||
export async function startDefaultMCPServer() {
|
||||
validateAuth()
|
||||
|
||||
const server = new Server(
|
||||
{
|
||||
name: 'humanlayer-standalone',
|
||||
version: '1.0.0',
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
tools: {},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
const resolvedConfig = resolveFullConfig({})
|
||||
|
||||
const hl = humanlayer({
|
||||
apiKey: resolvedConfig.api_key,
|
||||
...(resolvedConfig.api_base_url && { apiBaseUrl: resolvedConfig.api_base_url }),
|
||||
...(resolvedConfig.run_id && { runId: resolvedConfig.run_id }),
|
||||
...(Object.keys(resolvedConfig.contact_channel).length > 0 && {
|
||||
contactChannel: resolvedConfig.contact_channel,
|
||||
}),
|
||||
})
|
||||
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||
return {
|
||||
tools: [
|
||||
{
|
||||
name: 'contact_human',
|
||||
description: 'Contact a human for assistance',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
message: { type: 'string' },
|
||||
},
|
||||
required: ['message'],
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
})
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||
if (request.params.name === 'contact_human') {
|
||||
const response = await hl.fetchHumanResponse({
|
||||
spec: {
|
||||
msg: request.params.arguments?.message,
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: response,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name')
|
||||
})
|
||||
|
||||
const transport = new StdioServerTransport()
|
||||
await server.connect(transport)
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the Claude approvals MCP server that provides request_permission functionality
|
||||
* Returns responses in the format required by Claude Code SDK
|
||||
*
|
||||
* This now uses local approvals through the daemon instead of HumanLayer API
|
||||
*/
|
||||
export async function startClaudeApprovalsMCPServer() {
|
||||
// No auth validation needed - uses local daemon
|
||||
logger.info('Starting Claude approvals MCP server')
|
||||
logger.info('Environment variables', {
|
||||
HUMANLAYER_DAEMON_SOCKET: process.env.HUMANLAYER_DAEMON_SOCKET,
|
||||
HUMANLAYER_RUN_ID: process.env.HUMANLAYER_RUN_ID,
|
||||
})
|
||||
|
||||
const server = new Server(
|
||||
{
|
||||
name: 'humanlayer-claude-local-approvals',
|
||||
version: '1.0.0',
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
tools: {},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Create daemon client with socket path from environment or config
|
||||
// The daemon sets HUMANLAYER_DAEMON_SOCKET for MCP servers it launches
|
||||
const resolvedConfig = resolveFullConfig({})
|
||||
const socketPath = process.env.HUMANLAYER_DAEMON_SOCKET || resolvedConfig.daemon_socket
|
||||
logger.info('Creating daemon client', { socketPath })
|
||||
const daemonClient = new DaemonClient(socketPath)
|
||||
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||
logger.info('ListTools request received')
|
||||
const tools = [
|
||||
{
|
||||
name: 'request_permission',
|
||||
description: 'Request permission to perform an action',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
tool_name: { type: 'string' },
|
||||
input: { type: 'object' },
|
||||
},
|
||||
required: ['tool_name', 'input'],
|
||||
},
|
||||
},
|
||||
]
|
||||
logger.info('Returning tools', { tools })
|
||||
return { tools }
|
||||
})
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||
logger.debug('Received tool call request', { name: request.params.name })
|
||||
|
||||
if (request.params.name === 'request_permission') {
|
||||
const toolName: string | undefined = request.params.arguments?.tool_name
|
||||
|
||||
if (!toolName) {
|
||||
logger.error('Invalid tool name in request_permission', request.params.arguments)
|
||||
throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name requesting permissions')
|
||||
}
|
||||
|
||||
const input: Record<string, unknown> = request.params.arguments?.input || {}
|
||||
|
||||
// Get run ID from environment (set by Claude Code)
|
||||
const runId = process.env.HUMANLAYER_RUN_ID
|
||||
if (!runId) {
|
||||
logger.error('HUMANLAYER_RUN_ID not set in environment')
|
||||
throw new McpError(ErrorCode.InternalError, 'HUMANLAYER_RUN_ID not set')
|
||||
}
|
||||
|
||||
logger.info('Processing approval request', { runId, toolName })
|
||||
|
||||
try {
|
||||
// Connect to daemon
|
||||
logger.debug('Connecting to daemon...')
|
||||
await daemonClient.connect()
|
||||
logger.debug('Connected to daemon')
|
||||
|
||||
// Create approval request
|
||||
logger.debug('Creating approval request...', { runId, toolName })
|
||||
const createResponse = await daemonClient.createApproval(runId, toolName, input)
|
||||
const approvalId = createResponse.approval_id
|
||||
logger.info('Created approval', { approvalId })
|
||||
|
||||
// Poll for approval status
|
||||
let approved = false
|
||||
let comment = ''
|
||||
let polling = true
|
||||
|
||||
while (polling) {
|
||||
try {
|
||||
// Get the specific approval by ID
|
||||
logger.debug('Fetching approval status...', { approvalId })
|
||||
const approval = (await daemonClient.getApproval(approvalId)) as {
|
||||
id: string
|
||||
status: string
|
||||
comment?: string
|
||||
}
|
||||
|
||||
logger.debug('Approval status', { status: approval.status })
|
||||
|
||||
if (approval.status !== 'pending') {
|
||||
// Approval has been resolved
|
||||
approved = approval.status === 'approved'
|
||||
comment = approval.comment || ''
|
||||
polling = false
|
||||
logger.info('Approval resolved', {
|
||||
approvalId,
|
||||
status: approval.status,
|
||||
approved,
|
||||
})
|
||||
} else {
|
||||
// Still pending, wait and poll again
|
||||
logger.debug('Approval still pending, polling again...')
|
||||
await new Promise(resolve => setTimeout(resolve, 1000))
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to get approval status', { error, approvalId })
|
||||
// Re-throw the error since this is a critical failure
|
||||
throw new McpError(ErrorCode.InternalError, 'Failed to get approval status')
|
||||
}
|
||||
}
|
||||
|
||||
if (!approved) {
|
||||
logger.info('Approval denied', { approvalId, comment })
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
behavior: 'deny',
|
||||
message: comment || 'Request denied by human reviewer',
|
||||
}),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('Approval granted', { approvalId })
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
behavior: 'allow',
|
||||
updatedInput: input,
|
||||
}),
|
||||
},
|
||||
],
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to process approval', error)
|
||||
throw new McpError(
|
||||
ErrorCode.InternalError,
|
||||
`Failed to process approval: ${error instanceof Error ? error.message : String(error)}`,
|
||||
)
|
||||
} finally {
|
||||
logger.debug('Closing daemon connection')
|
||||
daemonClient.close()
|
||||
}
|
||||
}
|
||||
|
||||
throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name')
|
||||
})
|
||||
|
||||
const transport = new StdioServerTransport()
|
||||
|
||||
try {
|
||||
await server.connect(transport)
|
||||
logger.info('MCP server connected and ready')
|
||||
} catch (error) {
|
||||
logger.error('Failed to start MCP server', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@@ -95,16 +95,6 @@ export function useApprovalsWithSubscription(sessionId?: string): UseApprovalsRe
|
||||
onEvent: event => {
|
||||
if (!isSubscribed) return
|
||||
|
||||
// Phase 7: Debug logging to verify tool_use_id flows through
|
||||
if (event.type === 'new_approval' || event.type === 'approval_resolved') {
|
||||
console.debug('Approval event with tool_use_id:', {
|
||||
type: event.type,
|
||||
approval_id: event.data?.approval_id,
|
||||
tool_use_id: event.data?.tool_use_id,
|
||||
data: event.data,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle different event types
|
||||
switch (event.type) {
|
||||
case 'new_approval':
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { create } from 'zustand'
|
||||
import { daemonClient } from '@/lib/daemon'
|
||||
import type { LaunchSessionRequest } from '@/lib/daemon/types'
|
||||
import { getDaemonUrl } from '@/lib/daemon/http-config'
|
||||
import { useHotkeysContext } from 'react-hotkeys-hook'
|
||||
import { SessionTableHotkeysScope } from '@/components/internal/SessionTable'
|
||||
import { exists } from '@tauri-apps/plugin-fs'
|
||||
@@ -142,13 +141,11 @@ export const useSessionLauncher = create<LauncherState>((set, get) => ({
|
||||
set({ isLaunching: true, error: undefined })
|
||||
|
||||
// Build MCP config (approvals enabled by default)
|
||||
// Use HTTP-based MCP server built into the daemon
|
||||
const daemonUrl = await getDaemonUrl()
|
||||
const mcpConfig = {
|
||||
mcpServers: {
|
||||
approvals: {
|
||||
type: 'http',
|
||||
url: `${daemonUrl}/api/v1/mcp`,
|
||||
command: 'npx',
|
||||
args: ['humanlayer', 'mcp', 'claude_approvals'],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -160,7 +157,7 @@ export const useSessionLauncher = create<LauncherState>((set, get) => ({
|
||||
model: config.model || undefined,
|
||||
max_turns: config.maxTurns || undefined,
|
||||
mcp_config: mcpConfig,
|
||||
permission_prompt_tool: 'mcp__approvals__request_approval',
|
||||
permission_prompt_tool: 'mcp__approvals__request_permission',
|
||||
}
|
||||
|
||||
const response = await daemonClient.launchSession(request)
|
||||
|
||||
Reference in New Issue
Block a user