Implement local-only approvals, removing HumanLayer API dependencies (#261)

* Implement local-only approvals, removing HumanLayer API dependencies

Major refactoring to simplify the approval system by making it entirely local:

- Add approvals table to SQLite with CRUD operations
- Replace complex HumanLayer API types with simple local types
- Rewrite approval manager to only handle local SQLite operations
- Remove API polling, correlation, and remote approval code
- Update RPC handlers to support local approvals with backward compatibility
- Update MCP integration to create and poll local approvals
- Update TUI to work with new simplified approval format
- Fix all tests including integration tests

This removes ~1800 lines of API integration code while maintaining all
functionality. Approvals are now faster, simpler, and have no external
dependencies.

* Add getApproval RPC endpoint for local approval polling

- Add GetApproval method to Manager interface
- Implement GetApproval in manager to retrieve specific approvals by ID
- Add HandleGetApproval RPC handler to expose endpoint
- Add migration 4 to create approvals table for existing databases

This allows the MCP server to poll for specific approval status
instead of fetching all pending approvals repeatedly.

* formatting

* Update MCP server to use getApproval endpoint

- Add getApproval method to daemonClient with proper TypeScript types
- Refactor MCP approval polling to use single approval lookup instead of fetching all
- Improve Event interface documentation with actual field names from daemon
- Add MCP-specific logger for better debugging and monitoring
- Simplify polling logic for cleaner, more efficient approval status checks

* Add test tooling and documentation for local approvals

- Add comprehensive test script for automated and interactive approval testing
- Create test_local_approvals.md documentation with usage examples
- Add .npmignore to exclude test files from npm package
- Update README with reference to test documentation
- Test script features auto-exit on session completion and random content generation
- Fix event field names to match actual daemon event data structure

* formatting
This commit is contained in:
Allison Durham
2025-07-01 16:35:36 -07:00
committed by GitHub
parent a54f57886b
commit 080277b05a
27 changed files with 2056 additions and 2671 deletions

View File

@@ -61,7 +61,7 @@ status:
# Generate mocks
mocks:
mockgen -source=session/types.go -destination=session/mock_session.go -package=session SessionManager
mockgen -source=approval/types.go -destination=approval/mock_approval.go -package=approval Manager,Store,APIClient
mockgen -source=approval/types.go -destination=approval/mock_approval.go -package=approval Manager
mockgen -source=client/types.go -destination=client/mock_client.go -package=client Client,Factory
mockgen -source=bus/types.go -destination=bus/mock_bus.go -package=bus EventBus
mockgen -source=store/store.go -destination=store/mock_store.go -package=store ConversationStore

View File

@@ -1,157 +0,0 @@
package approval
import (
"context"
"fmt"
"testing"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"go.uber.org/mock/gomock"
)
func TestPoller_CorrelateApproval(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tests := []struct {
name string
functionCall humanlayer.FunctionCall
setupStore func(mockStore *store.MockConversationStore)
expectCorrelate bool
expectStatusUpdate bool
}{
{
name: "successful_correlation",
functionCall: humanlayer.FunctionCall{
CallID: "fc-123",
RunID: "run-456",
Spec: humanlayer.FunctionCallSpec{
Fn: "dangerous_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// Expect GetSessionByRunID to find the session
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-456").
Return(&store.Session{
ID: "sess-789",
RunID: "run-456",
}, nil)
// Return a pending tool call
mockStore.EXPECT().
GetPendingToolCall(gomock.Any(), "sess-789", "dangerous_function").
Return(&store.ConversationEvent{
ID: 1,
SessionID: "sess-789",
ToolName: "dangerous_function",
}, nil)
// Expect correlation
mockStore.EXPECT().
CorrelateApproval(gomock.Any(), "sess-789", "dangerous_function", "fc-123").
Return(nil)
// Expect status update to waiting_input
mockStore.EXPECT().
UpdateSession(gomock.Any(), "sess-789", gomock.Any()).
DoAndReturn(func(ctx context.Context, sessionID string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusWaitingInput {
t.Errorf("expected status update to waiting_input, got %v", update.Status)
}
return nil
})
},
expectCorrelate: true,
expectStatusUpdate: true,
},
{
name: "no_matching_session",
functionCall: humanlayer.FunctionCall{
CallID: "fc-999",
RunID: "unknown-run",
Spec: humanlayer.FunctionCallSpec{
Fn: "some_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns nil (no matching session)
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "unknown-run").
Return(nil, nil)
// No further calls should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
{
name: "nil_tool_call_returned",
functionCall: humanlayer.FunctionCall{
CallID: "fc-nil",
RunID: "run-nil",
Spec: humanlayer.FunctionCallSpec{
Fn: "nil_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns a matching session
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-nil").
Return(&store.Session{
ID: "sess-nil",
RunID: "run-nil",
}, nil)
// Return nil tool call (no error but no result)
mockStore.EXPECT().
GetPendingToolCall(gomock.Any(), "sess-nil", "nil_function").
Return(nil, nil)
// No correlation or status update should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
{
name: "get_session_error",
functionCall: humanlayer.FunctionCall{
CallID: "fc-error",
RunID: "run-error",
Spec: humanlayer.FunctionCallSpec{
Fn: "error_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns an error
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-error").
Return(nil, fmt.Errorf("database error"))
// No further calls should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock store
mockStore := store.NewMockConversationStore(ctrl)
tt.setupStore(mockStore)
// Create poller with mock store
poller := &Poller{
conversationStore: mockStore,
}
// Call correlateApproval
ctx := context.Background()
poller.correlateApproval(ctx, tt.functionCall)
// The expectations are verified by gomock
})
}
}

View File

@@ -1,294 +0,0 @@
package approval
import (
"fmt"
"sync"
"time"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// MemoryStore is an in-memory implementation of Store
type MemoryStore struct {
mu sync.RWMutex
functionCalls map[string]*humanlayer.FunctionCall // indexed by call_id
humanContacts map[string]*humanlayer.HumanContact // indexed by call_id
byRunID map[string][]string // run_id -> []call_id
}
// NewMemoryStore creates a new in-memory store
func NewMemoryStore() *MemoryStore {
return &MemoryStore{
functionCalls: make(map[string]*humanlayer.FunctionCall),
humanContacts: make(map[string]*humanlayer.HumanContact),
byRunID: make(map[string][]string),
}
}
// StoreFunctionCall stores a function call approval request
func (s *MemoryStore) StoreFunctionCall(fc humanlayer.FunctionCall) error {
s.mu.Lock()
defer s.mu.Unlock()
// Store by call_id
s.functionCalls[fc.CallID] = &fc
// Index by run_id if present
if fc.RunID != "" {
s.addToRunIndex(fc.RunID, fc.CallID)
}
return nil
}
// StoreHumanContact stores a human contact request
func (s *MemoryStore) StoreHumanContact(hc humanlayer.HumanContact) error {
s.mu.Lock()
defer s.mu.Unlock()
// Store by call_id
s.humanContacts[hc.CallID] = &hc
// Index by run_id if present
if hc.RunID != "" {
s.addToRunIndex(hc.RunID, hc.CallID)
}
return nil
}
// GetFunctionCall retrieves a function call by ID
func (s *MemoryStore) GetFunctionCall(callID string) (*humanlayer.FunctionCall, error) {
s.mu.RLock()
defer s.mu.RUnlock()
fc, ok := s.functionCalls[callID]
if !ok {
return nil, fmt.Errorf("function call not found: %s", callID)
}
return fc, nil
}
// GetHumanContact retrieves a human contact by ID
func (s *MemoryStore) GetHumanContact(callID string) (*humanlayer.HumanContact, error) {
s.mu.RLock()
defer s.mu.RUnlock()
hc, ok := s.humanContacts[callID]
if !ok {
return nil, fmt.Errorf("human contact not found: %s", callID)
}
return hc, nil
}
// GetAllPending returns all pending approvals
func (s *MemoryStore) GetAllPending() ([]PendingApproval, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var approvals []PendingApproval
// Add function calls
for _, fc := range s.functionCalls {
// Only include if not yet responded
if fc.Status == nil || fc.Status.RespondedAt == nil {
approvals = append(approvals, PendingApproval{
Type: "function_call",
FunctionCall: fc,
})
}
}
// Add human contacts
for _, hc := range s.humanContacts {
// Only include if not yet responded
if hc.Status == nil || hc.Status.RespondedAt == nil {
approvals = append(approvals, PendingApproval{
Type: "human_contact",
HumanContact: hc,
})
}
}
return approvals, nil
}
// GetPendingByRunID returns pending approvals for a specific run_id
func (s *MemoryStore) GetPendingByRunID(runID string) ([]PendingApproval, error) {
s.mu.RLock()
defer s.mu.RUnlock()
callIDs, ok := s.byRunID[runID]
if !ok {
return []PendingApproval{}, nil
}
var approvals []PendingApproval
for _, callID := range callIDs {
// Check function calls
if fc, ok := s.functionCalls[callID]; ok {
if fc.Status == nil || fc.Status.RespondedAt == nil {
approvals = append(approvals, PendingApproval{
Type: "function_call",
FunctionCall: fc,
})
}
}
// Check human contacts
if hc, ok := s.humanContacts[callID]; ok {
if hc.Status == nil || hc.Status.RespondedAt == nil {
approvals = append(approvals, PendingApproval{
Type: "human_contact",
HumanContact: hc,
})
}
}
}
return approvals, nil
}
// MarkFunctionCallResponded marks a function call as responded
func (s *MemoryStore) MarkFunctionCallResponded(callID string) error {
s.mu.Lock()
defer s.mu.Unlock()
fc, ok := s.functionCalls[callID]
if !ok {
return fmt.Errorf("function call not found: %s", callID)
}
// Update status to mark as responded
if fc.Status == nil {
fc.Status = &humanlayer.FunctionCallStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
fc.Status.RespondedAt = &now
return nil
}
// MarkHumanContactResponded marks a human contact as responded
func (s *MemoryStore) MarkHumanContactResponded(callID string) error {
s.mu.Lock()
defer s.mu.Unlock()
hc, ok := s.humanContacts[callID]
if !ok {
return fmt.Errorf("human contact not found: %s", callID)
}
// Update status to mark as responded
if hc.Status == nil {
hc.Status = &humanlayer.HumanContactStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
hc.Status.RespondedAt = &now
return nil
}
// GetAllCachedFunctionCalls returns all cached function calls
func (s *MemoryStore) GetAllCachedFunctionCalls() ([]humanlayer.FunctionCall, error) {
s.mu.RLock()
defer s.mu.RUnlock()
calls := make([]humanlayer.FunctionCall, 0, len(s.functionCalls))
for _, fc := range s.functionCalls {
calls = append(calls, *fc)
}
return calls, nil
}
// GetAllCachedHumanContacts returns all cached human contacts
func (s *MemoryStore) GetAllCachedHumanContacts() ([]humanlayer.HumanContact, error) {
s.mu.RLock()
defer s.mu.RUnlock()
contacts := make([]humanlayer.HumanContact, 0, len(s.humanContacts))
for _, hc := range s.humanContacts {
contacts = append(contacts, *hc)
}
return contacts, nil
}
// RemoveFunctionCall removes a function call from the store
func (s *MemoryStore) RemoveFunctionCall(callID string) error {
s.mu.Lock()
defer s.mu.Unlock()
fc, ok := s.functionCalls[callID]
if !ok {
return nil // Already removed
}
// Remove from main map
delete(s.functionCalls, callID)
// Remove from run_id index
if fc.RunID != "" {
s.removeFromRunIndex(fc.RunID, callID)
}
return nil
}
// RemoveHumanContact removes a human contact from the store
func (s *MemoryStore) RemoveHumanContact(callID string) error {
s.mu.Lock()
defer s.mu.Unlock()
hc, ok := s.humanContacts[callID]
if !ok {
return nil // Already removed
}
// Remove from main map
delete(s.humanContacts, callID)
// Remove from run_id index
if hc.RunID != "" {
s.removeFromRunIndex(hc.RunID, callID)
}
return nil
}
// addToRunIndex adds a call_id to the run_id index
func (s *MemoryStore) addToRunIndex(runID, callID string) {
if _, exists := s.byRunID[runID]; !exists {
s.byRunID[runID] = []string{}
}
// Check if callID already exists to avoid duplicates
for _, id := range s.byRunID[runID] {
if id == callID {
return
}
}
s.byRunID[runID] = append(s.byRunID[runID], callID)
}
// removeFromRunIndex removes a call_id from the run_id index
func (s *MemoryStore) removeFromRunIndex(runID, callID string) {
ids, exists := s.byRunID[runID]
if !exists {
return
}
// Filter out the callID
newIDs := make([]string, 0, len(ids))
for _, id := range ids {
if id != callID {
newIDs = append(newIDs, id)
}
}
if len(newIDs) == 0 {
delete(s.byRunID, runID)
} else {
s.byRunID[runID] = newIDs
}
}

View File

@@ -1,311 +0,0 @@
package approval
import (
"fmt"
"testing"
"time"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
func TestMemoryStore_StoreFunctionCall(t *testing.T) {
store := NewMemoryStore()
fc := humanlayer.FunctionCall{
CallID: "test-call-1",
RunID: "test-run-1",
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
// Store function call
err := store.StoreFunctionCall(fc)
if err != nil {
t.Fatalf("failed to store function call: %v", err)
}
// Retrieve it
retrieved, err := store.GetFunctionCall("test-call-1")
if err != nil {
t.Fatalf("failed to get function call: %v", err)
}
if retrieved.CallID != fc.CallID {
t.Errorf("expected call_id %s, got %s", fc.CallID, retrieved.CallID)
}
if retrieved.RunID != fc.RunID {
t.Errorf("expected run_id %s, got %s", fc.RunID, retrieved.RunID)
}
}
func TestMemoryStore_StoreHumanContact(t *testing.T) {
store := NewMemoryStore()
hc := humanlayer.HumanContact{
CallID: "test-contact-1",
RunID: "test-run-1",
Spec: humanlayer.HumanContactSpec{
Msg: "Test message",
},
}
// Store human contact
err := store.StoreHumanContact(hc)
if err != nil {
t.Fatalf("failed to store human contact: %v", err)
}
// Retrieve it
retrieved, err := store.GetHumanContact("test-contact-1")
if err != nil {
t.Fatalf("failed to get human contact: %v", err)
}
if retrieved.CallID != hc.CallID {
t.Errorf("expected call_id %s, got %s", hc.CallID, retrieved.CallID)
}
if retrieved.RunID != hc.RunID {
t.Errorf("expected run_id %s, got %s", hc.RunID, retrieved.RunID)
}
}
func TestMemoryStore_GetAllPending(t *testing.T) {
store := NewMemoryStore()
// Store some approvals
fc1 := humanlayer.FunctionCall{
CallID: "fc-1",
RunID: "run-1",
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
}
fc2 := humanlayer.FunctionCall{
CallID: "fc-2",
RunID: "run-2",
Spec: humanlayer.FunctionCallSpec{Fn: "func2"},
Status: &humanlayer.FunctionCallStatus{
RespondedAt: &humanlayer.CustomTime{Time: time.Now()},
},
}
hc1 := humanlayer.HumanContact{
CallID: "hc-1",
RunID: "run-1",
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
}
if err := store.StoreFunctionCall(fc1); err != nil {
t.Fatalf("failed to store function call fc1: %v", err)
}
if err := store.StoreFunctionCall(fc2); err != nil {
t.Fatalf("failed to store function call fc2: %v", err)
}
if err := store.StoreHumanContact(hc1); err != nil {
t.Fatalf("failed to store human contact hc1: %v", err)
}
// Get all pending
pending, err := store.GetAllPending()
if err != nil {
t.Fatalf("failed to get pending: %v", err)
}
// Should have 2 pending (fc1 and hc1, not fc2)
if len(pending) != 2 {
t.Errorf("expected 2 pending approvals, got %d", len(pending))
}
// Verify the right ones are included
foundFC1 := false
foundHC1 := false
for _, p := range pending {
if p.Type == "function_call" && p.FunctionCall.CallID == "fc-1" {
foundFC1 = true
}
if p.Type == "human_contact" && p.HumanContact.CallID == "hc-1" {
foundHC1 = true
}
}
if !foundFC1 {
t.Error("expected to find fc-1 in pending approvals")
}
if !foundHC1 {
t.Error("expected to find hc-1 in pending approvals")
}
}
func TestMemoryStore_GetPendingByRunID(t *testing.T) {
store := NewMemoryStore()
// Store approvals for different runs
fc1 := humanlayer.FunctionCall{
CallID: "fc-1",
RunID: "run-1",
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
}
fc2 := humanlayer.FunctionCall{
CallID: "fc-2",
RunID: "run-2",
Spec: humanlayer.FunctionCallSpec{Fn: "func2"},
}
hc1 := humanlayer.HumanContact{
CallID: "hc-1",
RunID: "run-1",
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
}
if err := store.StoreFunctionCall(fc1); err != nil {
t.Fatalf("failed to store function call fc1: %v", err)
}
if err := store.StoreFunctionCall(fc2); err != nil {
t.Fatalf("failed to store function call fc2: %v", err)
}
if err := store.StoreHumanContact(hc1); err != nil {
t.Fatalf("failed to store human contact hc1: %v", err)
}
// Get pending for run-1
pending, err := store.GetPendingByRunID("run-1")
if err != nil {
t.Fatalf("failed to get pending by run_id: %v", err)
}
// Should have 2 approvals for run-1
if len(pending) != 2 {
t.Errorf("expected 2 pending approvals for run-1, got %d", len(pending))
}
// Get pending for run-2
pending, err = store.GetPendingByRunID("run-2")
if err != nil {
t.Fatalf("failed to get pending by run_id: %v", err)
}
// Should have 1 approval for run-2
if len(pending) != 1 {
t.Errorf("expected 1 pending approval for run-2, got %d", len(pending))
}
// Get pending for non-existent run
pending, err = store.GetPendingByRunID("run-999")
if err != nil {
t.Fatalf("failed to get pending by run_id: %v", err)
}
// Should have 0 approvals
if len(pending) != 0 {
t.Errorf("expected 0 pending approvals for non-existent run, got %d", len(pending))
}
}
func TestMemoryStore_MarkResponded(t *testing.T) {
store := NewMemoryStore()
// Store approvals
fc := humanlayer.FunctionCall{
CallID: "fc-1",
RunID: "run-1",
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
}
hc := humanlayer.HumanContact{
CallID: "hc-1",
RunID: "run-1",
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
}
if err := store.StoreFunctionCall(fc); err != nil {
t.Fatalf("failed to store function call: %v", err)
}
if err := store.StoreHumanContact(hc); err != nil {
t.Fatalf("failed to store human contact: %v", err)
}
// Mark function call as responded
err := store.MarkFunctionCallResponded("fc-1")
if err != nil {
t.Fatalf("failed to mark function call responded: %v", err)
}
// Mark human contact as responded
err = store.MarkHumanContactResponded("hc-1")
if err != nil {
t.Fatalf("failed to mark human contact responded: %v", err)
}
// Verify they're no longer pending
pending, err := store.GetAllPending()
if err != nil {
t.Fatalf("failed to get pending: %v", err)
}
if len(pending) != 0 {
t.Errorf("expected 0 pending approvals after marking responded, got %d", len(pending))
}
}
func TestMemoryStore_ConcurrentAccess(t *testing.T) {
store := NewMemoryStore()
done := make(chan bool)
// Concurrent writes
go func() {
var errors []error
for i := 0; i < 100; i++ {
fc := humanlayer.FunctionCall{
CallID: fmt.Sprintf("fc-%d", i),
RunID: fmt.Sprintf("run-%d", i%10),
Spec: humanlayer.FunctionCallSpec{Fn: "func"},
}
if err := store.StoreFunctionCall(fc); err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
t.Errorf("StoreFunctionCall errors during concurrent access: %v", errors)
}
done <- true
}()
go func() {
var errors []error
for i := 0; i < 100; i++ {
hc := humanlayer.HumanContact{
CallID: fmt.Sprintf("hc-%d", i),
RunID: fmt.Sprintf("run-%d", i%10),
Spec: humanlayer.HumanContactSpec{Msg: "msg"},
}
if err := store.StoreHumanContact(hc); err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
t.Errorf("StoreHumanContact errors during concurrent access: %v", errors)
}
done <- true
}()
// Concurrent reads
go func() {
for i := 0; i < 100; i++ {
_, _ = store.GetAllPending()
_, _ = store.GetPendingByRunID(fmt.Sprintf("run-%d", i%10))
}
done <- true
}()
// Wait for all goroutines
for i := 0; i < 3; i++ {
<-done
}
// Verify data integrity
pending, err := store.GetAllPending()
if err != nil {
t.Fatalf("failed to get pending: %v", err)
}
// Should have 200 approvals (100 function calls + 100 human contacts)
if len(pending) != 200 {
t.Errorf("expected 200 pending approvals, got %d", len(pending))
}
}

View File

@@ -2,364 +2,233 @@ package approval
import (
"context"
"errors"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// Config holds configuration for the approval manager
type Config struct {
APIKey string
BaseURL string
PollInterval time.Duration
MaxBackoff time.Duration
BackoffFactor float64
// manager manages approvals locally without HumanLayer API
type manager struct {
store store.ConversationStore
eventBus bus.EventBus
}
// DefaultManager is the default implementation of Manager
type DefaultManager struct {
Client APIClient
Store Store
Poller *Poller
EventBus bus.EventBus
ConversationStore store.ConversationStore
// NewManager creates a new local approval manager
func NewManager(store store.ConversationStore, eventBus bus.EventBus) Manager {
return &manager{
store: store,
eventBus: eventBus,
}
}
// NewManager creates a new approval manager
func NewManager(cfg Config, eventBus bus.EventBus, conversationStore store.ConversationStore) (Manager, error) {
// Set defaults
if cfg.PollInterval <= 0 {
cfg.PollInterval = 5 * time.Second
}
if cfg.MaxBackoff <= 0 {
cfg.MaxBackoff = 5 * time.Minute
}
if cfg.BackoffFactor <= 0 {
cfg.BackoffFactor = 2.0
}
// Create HumanLayer client
opts := []humanlayer.ClientOption{
humanlayer.WithAPIKey(cfg.APIKey),
}
// Add base URL if provided
if cfg.BaseURL != "" {
opts = append(opts, humanlayer.WithBaseURL(cfg.BaseURL))
}
client, err := humanlayer.NewClient(opts...)
// CreateApproval creates a new local approval
func (m *manager) CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error) {
// Look up session by run_id
session, err := m.store.GetSessionByRunID(ctx, runID)
if err != nil {
return nil, fmt.Errorf("failed to create HumanLayer client: %w", err)
return "", fmt.Errorf("failed to get session by run_id: %w", err)
}
if session == nil {
return "", fmt.Errorf("session not found for run_id: %s", runID)
}
// Create in-memory store
store := NewMemoryStore()
// Create poller with configured interval
poller := NewPoller(client, store, conversationStore, cfg.PollInterval, eventBus)
poller.maxBackoff = cfg.MaxBackoff
poller.backoffFactor = cfg.BackoffFactor
return &DefaultManager{
Client: client,
Store: store,
Poller: poller,
EventBus: eventBus,
ConversationStore: conversationStore,
}, nil
}
// Start begins polling for approvals
func (m *DefaultManager) Start(ctx context.Context) error {
return m.Poller.Start(ctx)
}
// Stop stops the approval manager
func (m *DefaultManager) Stop() {
m.Poller.Stop()
}
// GetPendingApprovals returns all pending approvals, optionally filtered by session
func (m *DefaultManager) GetPendingApprovals(sessionID string) ([]PendingApproval, error) {
if sessionID != "" && m.ConversationStore != nil {
// Look up the session to get its run_id
ctx := context.Background()
session, err := m.ConversationStore.GetSession(ctx, sessionID)
if err != nil {
slog.Debug("session not found for approval filter", "session_id", sessionID, "error", err)
return []PendingApproval{}, nil
}
// Get approvals by run_id
return m.Store.GetPendingByRunID(session.RunID)
// Create approval
approval := &store.Approval{
ID: "local-" + uuid.New().String(),
RunID: runID,
SessionID: session.ID,
Status: store.ApprovalStatusLocalPending,
CreatedAt: time.Now(),
ToolName: toolName,
ToolInput: toolInput,
}
return m.Store.GetAllPending()
}
// GetPendingApprovalsByRunID returns pending approvals for a specific run_id
func (m *DefaultManager) GetPendingApprovalsByRunID(runID string) ([]PendingApproval, error) {
return m.Store.GetPendingByRunID(runID)
}
// ErrAlreadyResponded is returned when an approval has already been responded to
var ErrAlreadyResponded = errors.New("this approval has already been responded to")
// handleConflictError checks if an error is a conflict and handles it appropriately
func (m *DefaultManager) handleConflictError(ctx context.Context, err error, callID string, approvalType string) error {
var apiErr *humanlayer.APIError
if errors.As(err, &apiErr) && apiErr.IsConflict() {
slog.Info("approval already responded externally",
"type", approvalType,
"call_id", callID,
"error", apiErr.Body)
// Remove from local cache
if approvalType == "function_call" {
if err := m.Store.RemoveFunctionCall(callID); err != nil {
slog.Error("failed to remove function call from cache", "call_id", callID, "error", err)
}
} else {
if err := m.Store.RemoveHumanContact(callID); err != nil {
slog.Error("failed to remove human contact from cache", "call_id", callID, "error", err)
}
}
// Update database status if we have a conversation store
if m.ConversationStore != nil {
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusResolved); err != nil {
slog.Error("failed to update approval status in database", "error", err)
}
}
// Return a specific error to indicate it was already responded
return ErrAlreadyResponded
// Store it
if err := m.store.CreateApproval(ctx, approval); err != nil {
return "", fmt.Errorf("failed to store approval: %w", err)
}
return err
// Try to correlate with the most recent uncorrelated tool call
if err := m.correlateApproval(ctx, approval); err != nil {
// Log but don't fail - correlation is best effort
slog.Warn("failed to correlate approval with tool call",
"error", err,
"approval_id", approval.ID,
"session_id", session.ID)
}
// Publish event for real-time updates
m.publishNewApprovalEvent(approval)
// Update session status to waiting_input
if err := m.updateSessionStatus(ctx, session.ID, store.SessionStatusWaitingInput); err != nil {
slog.Warn("failed to update session status",
"error", err,
"session_id", session.ID)
}
slog.Info("created local approval",
"approval_id", approval.ID,
"session_id", session.ID,
"tool_name", toolName)
return approval.ID, nil
}
// ApproveFunctionCall approves a function call
func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
// First check if we have this function call
fc, err := m.Store.GetFunctionCall(callID)
// GetPendingApprovals retrieves pending approvals for a session
func (m *manager) GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error) {
approvals, err := m.store.GetPendingApprovals(ctx, sessionID)
if err != nil {
return fmt.Errorf("function call not found: %w", err)
return nil, fmt.Errorf("failed to get pending approvals: %w", err)
}
return approvals, nil
}
// GetApproval retrieves a specific approval by ID
func (m *manager) GetApproval(ctx context.Context, id string) (*store.Approval, error) {
approval, err := m.store.GetApproval(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get approval: %w", err)
}
return approval, nil
}
// ApproveToolCall approves a tool call
func (m *manager) ApproveToolCall(ctx context.Context, id string, comment string) error {
// Get the approval first
approval, err := m.store.GetApproval(ctx, id)
if err != nil {
return fmt.Errorf("failed to get approval: %w", err)
}
// Send approval to API
if err := m.Client.ApproveFunctionCall(ctx, callID, comment); err != nil {
// Handle conflict error specially
conflictErr := m.handleConflictError(ctx, err, callID, "function_call")
if errors.Is(conflictErr, ErrAlreadyResponded) {
// Return the already responded error directly
return conflictErr
}
if conflictErr != nil {
return fmt.Errorf("failed to approve function call: %w", conflictErr)
}
// Update approval status
if err := m.store.UpdateApprovalResponse(ctx, id, store.ApprovalStatusLocalApproved, comment); err != nil {
return fmt.Errorf("failed to update approval: %w", err)
}
// Mark as responded in local store
if err := m.Store.MarkFunctionCallResponded(callID); err != nil {
return fmt.Errorf("failed to update local state: %w", err)
}
// Update approval status in database if we have a conversation store
if m.ConversationStore != nil && fc != nil {
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusApproved); err != nil {
slog.Error("failed to update approval status in database", "error", err)
// Don't fail the whole operation for this
}
// Update session status back to running since approval is resolved
if fc.RunID != "" {
// Look up the session by run_id
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
runningStatus := store.SessionStatusRunning
update := store.SessionUpdate{
Status: &runningStatus,
}
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to running",
"session_id", session.ID,
"error", err)
} else {
slog.Info("updated session status back to running after approval",
"session_id", session.ID,
"approval_id", callID)
// Publish session status change event
if m.EventBus != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventSessionStatusChanged,
// TODO(4): Can this be a static type later on? Why isn't it currently? Is this because of JSON RPC or a go thing?
Data: map[string]interface{}{
"session_id": session.ID,
"run_id": fc.RunID,
"old_status": string(store.SessionStatusWaitingInput),
"new_status": string(store.SessionStatusRunning),
},
})
}
}
}
}
// Update correlation status in conversation events
if err := m.store.UpdateApprovalStatus(ctx, id, store.ApprovalStatusApproved); err != nil {
slog.Warn("failed to update approval status in conversation events",
"error", err,
"approval_id", id)
}
// Publish event
if m.EventBus != nil && fc != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{
"type": "function_call",
"call_id": callID,
"run_id": fc.RunID,
"decision": "approved",
"comment": comment,
},
})
m.publishApprovalResolvedEvent(approval, true, comment)
// Update session status back to running
if err := m.updateSessionStatus(ctx, approval.SessionID, store.SessionStatusRunning); err != nil {
slog.Warn("failed to update session status",
"error", err,
"session_id", approval.SessionID)
}
slog.Info("approved tool call",
"approval_id", id,
"comment", comment)
return nil
}
// DenyToolCall denies a tool call
func (m *manager) DenyToolCall(ctx context.Context, id string, reason string) error {
// Get the approval first
approval, err := m.store.GetApproval(ctx, id)
if err != nil {
return fmt.Errorf("failed to get approval: %w", err)
}
// Update approval status
if err := m.store.UpdateApprovalResponse(ctx, id, store.ApprovalStatusLocalDenied, reason); err != nil {
return fmt.Errorf("failed to update approval: %w", err)
}
// Update correlation status in conversation events
if err := m.store.UpdateApprovalStatus(ctx, id, store.ApprovalStatusDenied); err != nil {
slog.Warn("failed to update approval status in conversation events",
"error", err,
"approval_id", id)
}
// Publish event
m.publishApprovalResolvedEvent(approval, false, reason)
// Update session status back to running
if err := m.updateSessionStatus(ctx, approval.SessionID, store.SessionStatusRunning); err != nil {
slog.Warn("failed to update session status",
"error", err,
"session_id", approval.SessionID)
}
slog.Info("denied tool call",
"approval_id", id,
"reason", reason)
return nil
}
// correlateApproval tries to correlate an approval with a tool call
func (m *manager) correlateApproval(ctx context.Context, approval *store.Approval) error {
// Find the most recent uncorrelated pending tool call
toolCall, err := m.store.GetUncorrelatedPendingToolCall(ctx, approval.SessionID, approval.ToolName)
if err != nil {
return fmt.Errorf("failed to find pending tool call: %w", err)
}
if toolCall == nil {
return fmt.Errorf("no matching tool call found")
}
// Correlate by tool ID
if err := m.store.CorrelateApprovalByToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil {
return fmt.Errorf("failed to correlate approval: %w", err)
}
return nil
}
// DenyFunctionCall denies a function call
func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
// First check if we have this function call
fc, err := m.Store.GetFunctionCall(callID)
if err != nil {
return fmt.Errorf("function call not found: %w", err)
}
// Send denial to API
if err := m.Client.DenyFunctionCall(ctx, callID, reason); err != nil {
// Handle conflict error specially
conflictErr := m.handleConflictError(ctx, err, callID, "function_call")
if errors.Is(conflictErr, ErrAlreadyResponded) {
// Return the already responded error directly
return conflictErr
}
if conflictErr != nil {
return fmt.Errorf("failed to deny function call: %w", conflictErr)
}
}
// Mark as responded in local store
if err := m.Store.MarkFunctionCallResponded(callID); err != nil {
return fmt.Errorf("failed to update local state: %w", err)
}
// Update approval status in database if we have a conversation store
if m.ConversationStore != nil && fc != nil {
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusDenied); err != nil {
slog.Error("failed to update approval status in database", "error", err)
// Don't fail the whole operation for this
}
// Update session status back to running since approval is resolved (denied)
if fc.RunID != "" {
// Look up the session by run_id
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
runningStatus := store.SessionStatusRunning
update := store.SessionUpdate{
Status: &runningStatus,
}
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to running",
"session_id", session.ID,
"error", err)
} else {
slog.Info("updated session status back to running after denial",
"session_id", session.ID,
"approval_id", callID)
// Publish session status change event
if m.EventBus != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventSessionStatusChanged,
Data: map[string]interface{}{
"session_id": session.ID,
"run_id": fc.RunID,
"old_status": string(store.SessionStatusWaitingInput),
"new_status": string(store.SessionStatusRunning),
},
})
}
}
}
}
}
// Publish event
if m.EventBus != nil && fc != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
// publishNewApprovalEvent publishes an event when a new approval is created
func (m *manager) publishNewApprovalEvent(approval *store.Approval) {
if m.eventBus != nil {
event := bus.Event{
Type: bus.EventNewApproval,
Timestamp: time.Now(),
Data: map[string]interface{}{
"type": "function_call",
"call_id": callID,
"run_id": fc.RunID,
"decision": "denied",
"reason": reason,
"approval_id": approval.ID,
"session_id": approval.SessionID,
"tool_name": approval.ToolName,
},
})
}
m.eventBus.Publish(event)
}
return nil
}
// RespondToHumanContact sends a response to a human contact request
func (m *DefaultManager) RespondToHumanContact(ctx context.Context, callID string, response string) error {
// First check if we have this human contact
hc, err := m.Store.GetHumanContact(callID)
if err != nil {
return fmt.Errorf("human contact not found: %w", err)
}
// Send response to API
if err := m.Client.RespondToHumanContact(ctx, callID, response); err != nil {
// Handle conflict error specially
conflictErr := m.handleConflictError(ctx, err, callID, "human_contact")
if errors.Is(conflictErr, ErrAlreadyResponded) {
// Return the already responded error directly
return conflictErr
}
if conflictErr != nil {
return fmt.Errorf("failed to respond to human contact: %w", conflictErr)
}
}
// Mark as responded in local store
if err := m.Store.MarkHumanContactResponded(callID); err != nil {
return fmt.Errorf("failed to update local state: %w", err)
}
// Publish event
if m.EventBus != nil && hc != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
// publishApprovalResolvedEvent publishes an event when an approval is resolved
func (m *manager) publishApprovalResolvedEvent(approval *store.Approval, approved bool, responseText string) {
if m.eventBus != nil {
event := bus.Event{
Type: bus.EventApprovalResolved,
Timestamp: time.Now(),
Data: map[string]interface{}{
"type": "human_contact",
"call_id": callID,
"run_id": hc.RunID,
"decision": "responded",
"response": response,
"approval_id": approval.ID,
"session_id": approval.SessionID,
"approved": approved,
"response_text": responseText,
},
})
}
m.eventBus.Publish(event)
}
return nil
}
// ReconcileApprovalsForSession reconciles approvals for a session after restart
func (m *DefaultManager) ReconcileApprovalsForSession(ctx context.Context, runID string) error {
if m.Poller == nil {
return nil // No poller configured
// updateSessionStatus updates the session status
func (m *manager) updateSessionStatus(ctx context.Context, sessionID, status string) error {
updates := store.SessionUpdate{
Status: &status,
LastActivityAt: &[]time.Time{time.Now()}[0],
}
return m.Poller.ReconcileApprovalsForSession(ctx, runID)
return m.store.UpdateSession(ctx, sessionID, updates)
}

View File

@@ -1,196 +0,0 @@
package approval
import (
"context"
"testing"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"go.uber.org/mock/gomock"
)
func TestManager_SessionStatusTransitions(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tests := []struct {
name string
setupTest func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string // returns callID
expectedStatus string
}{
{
name: "approve_updates_session_to_running",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-approve-123"
runID := "run-approve-456"
sessionID := "sess-approve-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).AnyTimes()
// API approval succeeds
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
// Get session by run_id
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusWaitingInput,
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// Expect status update to running
convStore.EXPECT().
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusRunning {
t.Errorf("expected status update to running, got %v", update.Status)
}
return nil
})
return callID
},
expectedStatus: store.SessionStatusRunning,
},
{
name: "deny_updates_session_to_running",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-deny-123"
runID := "run-deny-456"
sessionID := "sess-deny-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
// API denial succeeds
mockClient.EXPECT().DenyFunctionCall(gomock.Any(), callID, "test reason").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusDenied).Return(nil)
// Get session by run_id
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusWaitingInput,
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// Expect status update to running
convStore.EXPECT().
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusRunning {
t.Errorf("expected status update to running, got %v", update.Status)
}
return nil
})
return callID
},
expectedStatus: store.SessionStatusRunning,
},
{
name: "no_status_update_for_non_waiting_session",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-nochange-123"
runID := "run-nochange-456"
sessionID := "sess-nochange-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
// API approval succeeds
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
// Get session - session is already running, not waiting
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusRunning, // Already running
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// No status update should happen since session is not waiting
return callID
},
expectedStatus: store.SessionStatusRunning,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a new controller for each test to isolate expectations
testCtrl := gomock.NewController(t)
defer testCtrl.Finish()
// Create mocks
mockClient := NewMockAPIClient(testCtrl)
mockStore := NewMockStore(testCtrl)
convStore := store.NewMockConversationStore(testCtrl)
// Create manager
manager := &DefaultManager{
Client: mockClient,
Store: mockStore,
ConversationStore: convStore,
}
// Setup test expectations and get callID
callID := tt.setupTest(mockClient, mockStore, convStore)
// Execute based on test type
ctx := context.Background()
var err error
if tt.name == "deny_updates_session_to_running" {
err = manager.DenyFunctionCall(ctx, callID, "test reason")
} else {
err = manager.ApproveFunctionCall(ctx, callID, "test comment")
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Expectations are verified by gomock
})
}
}

View File

@@ -0,0 +1,263 @@
package approval
import (
"context"
"encoding/json"
"strings"
"testing"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
func TestManager_CreateApproval(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
runID := "test-run-123"
sessionID := "test-session-456"
toolName := "Write"
toolInput := json.RawMessage(`{"file": "test.txt", "content": "hello"}`)
// Mock getting session by run ID
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(&store.Session{
ID: sessionID,
RunID: runID,
}, nil)
// Mock creating approval
mockStore.EXPECT().CreateApproval(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, approval *store.Approval) error {
assert.Equal(t, runID, approval.RunID)
assert.Equal(t, sessionID, approval.SessionID)
assert.Equal(t, store.ApprovalStatusLocalPending, approval.Status)
assert.Equal(t, toolName, approval.ToolName)
assert.Equal(t, toolInput, approval.ToolInput)
assert.NotEmpty(t, approval.ID)
assert.True(t, strings.HasPrefix(approval.ID, "local-"))
return nil
})
// Mock correlation attempt - it's ok if it fails
mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(nil, nil)
// Mock event publishing
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
assert.Equal(t, bus.EventNewApproval, event.Type)
assert.Equal(t, sessionID, event.Data["session_id"])
assert.Equal(t, toolName, event.Data["tool_name"])
})
// Mock session status update
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
// Create approval
approvalID, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
require.NoError(t, err)
assert.NotEmpty(t, approvalID)
assert.True(t, strings.HasPrefix(approvalID, "local-"))
}
func TestManager_CreateApproval_SessionNotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
runID := "test-run-123"
toolName := "Write"
toolInput := json.RawMessage(`{"file": "test.txt"}`)
// Mock getting session by run ID - returns nil
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(nil, nil)
// Create approval should fail
_, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
assert.Error(t, err)
assert.Contains(t, err.Error(), "session not found")
}
func TestManager_GetPendingApprovals(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
sessionID := "test-session-456"
expectedApprovals := []*store.Approval{
{
ID: "local-approval-1",
SessionID: sessionID,
Status: store.ApprovalStatusLocalPending,
ToolName: "Write",
},
{
ID: "local-approval-2",
SessionID: sessionID,
Status: store.ApprovalStatusLocalPending,
ToolName: "Execute",
},
}
mockStore.EXPECT().GetPendingApprovals(ctx, sessionID).Return(expectedApprovals, nil)
approvals, err := manager.GetPendingApprovals(ctx, sessionID)
require.NoError(t, err)
assert.Equal(t, expectedApprovals, approvals)
}
func TestManager_ApproveToolCall(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
approvalID := "local-approval-123"
sessionID := "test-session-456"
comment := "Looks good!"
approval := &store.Approval{
ID: approvalID,
SessionID: sessionID,
Status: store.ApprovalStatusLocalPending,
ToolName: "Write",
}
// Mock getting approval
mockStore.EXPECT().GetApproval(ctx, approvalID).Return(approval, nil)
// Mock updating approval response
mockStore.EXPECT().UpdateApprovalResponse(ctx, approvalID, store.ApprovalStatusLocalApproved, comment).Return(nil)
// Mock updating approval status in conversation events
mockStore.EXPECT().UpdateApprovalStatus(ctx, approvalID, store.ApprovalStatusApproved).Return(nil)
// Mock event publishing
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
assert.Equal(t, bus.EventApprovalResolved, event.Type)
assert.Equal(t, approvalID, event.Data["approval_id"])
assert.Equal(t, sessionID, event.Data["session_id"])
assert.Equal(t, true, event.Data["approved"])
assert.Equal(t, comment, event.Data["response_text"])
})
// Mock session status update
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
err := manager.ApproveToolCall(ctx, approvalID, comment)
require.NoError(t, err)
}
func TestManager_DenyToolCall(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
approvalID := "local-approval-123"
sessionID := "test-session-456"
reason := "Not safe to execute"
approval := &store.Approval{
ID: approvalID,
SessionID: sessionID,
Status: store.ApprovalStatusLocalPending,
ToolName: "Execute",
}
// Mock getting approval
mockStore.EXPECT().GetApproval(ctx, approvalID).Return(approval, nil)
// Mock updating approval response
mockStore.EXPECT().UpdateApprovalResponse(ctx, approvalID, store.ApprovalStatusLocalDenied, reason).Return(nil)
// Mock updating approval status in conversation events
mockStore.EXPECT().UpdateApprovalStatus(ctx, approvalID, store.ApprovalStatusDenied).Return(nil)
// Mock event publishing
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
assert.Equal(t, bus.EventApprovalResolved, event.Type)
assert.Equal(t, approvalID, event.Data["approval_id"])
assert.Equal(t, sessionID, event.Data["session_id"])
assert.Equal(t, false, event.Data["approved"])
assert.Equal(t, reason, event.Data["response_text"])
})
// Mock session status update
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
err := manager.DenyToolCall(ctx, approvalID, reason)
require.NoError(t, err)
}
func TestManager_CorrelateApproval(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
mockEventBus := bus.NewMockEventBus(ctrl)
manager := NewManager(mockStore, mockEventBus)
ctx := context.Background()
runID := "test-run-123"
sessionID := "test-session-456"
toolName := "Write"
toolInput := json.RawMessage(`{"file": "test.txt"}`)
// Mock getting session by run ID
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(&store.Session{
ID: sessionID,
RunID: runID,
}, nil)
// Mock creating approval
mockStore.EXPECT().CreateApproval(ctx, gomock.Any()).Return(nil)
// Mock successful correlation
pendingToolCall := &store.ConversationEvent{
ID: 123,
ToolID: "tool-123",
ToolName: toolName,
}
mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(pendingToolCall, nil)
// Mock correlating by tool ID
mockStore.EXPECT().CorrelateApprovalByToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil)
// Mock event publishing
mockEventBus.EXPECT().Publish(gomock.Any())
// Mock session status update
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
// Create approval (which will attempt correlation)
approvalID, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
require.NoError(t, err)
assert.NotEmpty(t, approvalID)
}

View File

@@ -1,561 +0,0 @@
package approval
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// Poller polls the HumanLayer API for pending approvals
type Poller struct {
client APIClient
store Store
conversationStore store.ConversationStore
eventBus bus.EventBus
interval time.Duration
maxBackoff time.Duration
backoffFactor float64
mu sync.Mutex
cancel context.CancelFunc
failureCount int
}
// NewPoller creates a new approval poller
func NewPoller(client APIClient, store Store, conversationStore store.ConversationStore, interval time.Duration, eventBus bus.EventBus) *Poller {
if interval <= 0 {
interval = 5 * time.Second
}
return &Poller{
client: client,
store: store,
conversationStore: conversationStore,
eventBus: eventBus,
interval: interval,
maxBackoff: 5 * time.Minute,
backoffFactor: 2.0,
failureCount: 0,
}
}
// Start begins polling for approvals
func (p *Poller) Start(ctx context.Context) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.cancel != nil {
return fmt.Errorf("poller already started")
}
ctx, cancel := context.WithCancel(ctx)
p.cancel = cancel
go p.pollLoop(ctx)
slog.Info("approval poller started", "interval", p.interval)
return nil
}
// Stop stops the polling loop
func (p *Poller) Stop() {
p.mu.Lock()
defer p.mu.Unlock()
if p.cancel != nil {
p.cancel()
p.cancel = nil
slog.Info("approval poller stopped")
}
}
// IsRunning returns true if the poller is currently running
func (p *Poller) IsRunning() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.cancel != nil
}
// pollLoop continuously polls for approvals
func (p *Poller) pollLoop(ctx context.Context) {
// Poll immediately on start
p.poll(ctx)
for {
// Calculate next poll interval based on failure count
interval := p.calculateInterval()
select {
case <-ctx.Done():
return
case <-time.After(interval):
p.poll(ctx)
}
}
}
// calculateInterval returns the next polling interval with exponential backoff
func (p *Poller) calculateInterval() time.Duration {
p.mu.Lock()
defer p.mu.Unlock()
if p.failureCount == 0 {
return p.interval
}
// Calculate backoff: interval * (backoffFactor ^ failureCount)
backoff := float64(p.interval)
for i := 0; i < p.failureCount; i++ {
backoff *= p.backoffFactor
if time.Duration(backoff) > p.maxBackoff {
return p.maxBackoff
}
}
return time.Duration(backoff)
}
// poll fetches and stores pending approvals
func (p *Poller) poll(ctx context.Context) {
// Create a timeout context for this poll
pollCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
var hadError bool
// Fetch function calls
functionCalls, err := p.client.GetPendingFunctionCalls(pollCtx)
if err != nil {
slog.Error("failed to fetch function calls", "error", err)
hadError = true
} else {
// Reconcile with cached state
p.reconcileFunctionCalls(pollCtx, functionCalls)
}
// Fetch human contacts
humanContacts, err := p.client.GetPendingHumanContacts(pollCtx)
if err != nil {
slog.Error("failed to fetch human contacts", "error", err)
hadError = true
} else {
// Reconcile with cached state
p.reconcileHumanContacts(pollCtx, humanContacts)
}
// Update failure count based on results
p.updateFailureCount(hadError)
}
// reconcileFunctionCalls reconciles fetched function calls with cached state
func (p *Poller) reconcileFunctionCalls(ctx context.Context, functionCalls []humanlayer.FunctionCall) {
// Get all cached function calls
cachedCalls, err := p.store.GetAllCachedFunctionCalls()
if err != nil {
slog.Error("failed to get cached function calls", "error", err)
return
}
// Create a map of fetched calls for quick lookup
fetchedMap := make(map[string]*humanlayer.FunctionCall)
for i := range functionCalls {
fc := &functionCalls[i]
fetchedMap[fc.CallID] = fc
}
// Check for calls that were resolved externally
removedCount := 0
for _, cached := range cachedCalls {
if _, exists := fetchedMap[cached.CallID]; !exists {
// This approval is no longer pending, it was resolved externally
slog.Info("detected externally resolved function call",
"call_id", cached.CallID,
"run_id", cached.RunID)
// Remove from local store
if err := p.store.RemoveFunctionCall(cached.CallID); err != nil {
slog.Error("failed to remove resolved function call", "error", err)
} else {
removedCount++
}
// Update database status if we have a conversation store
if p.conversationStore != nil {
if err := p.conversationStore.UpdateApprovalStatus(ctx, cached.CallID, store.ApprovalStatusResolved); err != nil {
slog.Error("failed to update approval status in database", "error", err)
}
// Update session status back to running if it was waiting
if cached.RunID != "" {
session, err := p.conversationStore.GetSessionByRunID(ctx, cached.RunID)
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
runningStatus := store.SessionStatusRunning
update := store.SessionUpdate{
Status: &runningStatus,
}
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status", "error", err)
} else {
// Publish session status change event
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventSessionStatusChanged,
Data: map[string]interface{}{
"session_id": session.ID,
"run_id": cached.RunID,
"old_status": string(store.SessionStatusWaitingInput),
"new_status": string(store.SessionStatusRunning),
},
})
}
}
}
}
}
// Publish event
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{
"type": "function_call",
"call_id": cached.CallID,
"run_id": cached.RunID,
"resolved_externally": true,
},
})
}
}
}
// Store new/updated function calls
newCount := 0
for _, fc := range functionCalls {
// Check if this is a new approval
if existing, err := p.store.GetFunctionCall(fc.CallID); err != nil || existing == nil {
newCount++
}
if err := p.store.StoreFunctionCall(fc); err != nil {
slog.Error("failed to store function call", "call_id", fc.CallID, "error", err)
}
// Correlate with tool calls in the database
if p.conversationStore != nil && fc.RunID != "" {
p.correlateApproval(ctx, fc)
}
}
slog.Debug("reconciled function calls",
"fetched", len(functionCalls),
"new", newCount,
"removed", removedCount)
// Publish event if we have new approvals
if newCount > 0 && p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventNewApproval,
Data: map[string]interface{}{
"type": "function_call",
"count": newCount,
"total": len(functionCalls),
},
})
}
}
// reconcileHumanContacts reconciles fetched human contacts with cached state
func (p *Poller) reconcileHumanContacts(ctx context.Context, humanContacts []humanlayer.HumanContact) {
// Get all cached human contacts
cachedContacts, err := p.store.GetAllCachedHumanContacts()
if err != nil {
slog.Error("failed to get cached human contacts", "error", err)
return
}
// Create a map of fetched contacts for quick lookup
fetchedMap := make(map[string]*humanlayer.HumanContact)
for i := range humanContacts {
hc := &humanContacts[i]
fetchedMap[hc.CallID] = hc
}
// Check for contacts that were resolved externally
removedCount := 0
for _, cached := range cachedContacts {
if _, exists := fetchedMap[cached.CallID]; !exists {
// This contact is no longer pending, it was resolved externally
slog.Info("detected externally resolved human contact",
"call_id", cached.CallID,
"run_id", cached.RunID)
// Remove from local store
if err := p.store.RemoveHumanContact(cached.CallID); err != nil {
slog.Error("failed to remove resolved human contact", "error", err)
} else {
removedCount++
}
// Publish event
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{
"type": "human_contact",
"call_id": cached.CallID,
"run_id": cached.RunID,
"resolved_externally": true,
},
})
}
}
}
// Store new/updated human contacts
newCount := 0
for _, hc := range humanContacts {
// Check if this is a new approval
if existing, err := p.store.GetHumanContact(hc.CallID); err != nil || existing == nil {
newCount++
}
if err := p.store.StoreHumanContact(hc); err != nil {
slog.Error("failed to store human contact", "call_id", hc.CallID, "error", err)
}
}
slog.Debug("reconciled human contacts",
"fetched", len(humanContacts),
"new", newCount,
"removed", removedCount)
// Publish event if we have new approvals
if newCount > 0 && p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventNewApproval,
Data: map[string]interface{}{
"type": "human_contact",
"count": newCount,
"total": len(humanContacts),
},
})
}
}
// ReconcileApprovalsForSession checks for any approvals that might belong to a session
// This is useful when a session restarts and needs to reclaim pending approvals
func (p *Poller) ReconcileApprovalsForSession(ctx context.Context, runID string) error {
// Get all cached function calls
functionCalls, err := p.store.GetAllCachedFunctionCalls()
if err != nil {
return fmt.Errorf("failed to get cached function calls: %w", err)
}
// Check each one to see if it matches this run_id
for _, fc := range functionCalls {
if fc.RunID == runID {
// Re-correlate this approval
p.correlateApproval(ctx, fc)
}
}
// Also check human contacts
humanContacts, err := p.store.GetAllCachedHumanContacts()
if err != nil {
return fmt.Errorf("failed to get cached human contacts: %w", err)
}
for _, hc := range humanContacts {
if hc.RunID == runID {
slog.Info("found orphaned human contact for session",
"call_id", hc.CallID,
"run_id", runID)
}
}
return nil
}
// updateFailureCount updates the failure count for backoff calculation
func (p *Poller) updateFailureCount(hadError bool) {
p.mu.Lock()
defer p.mu.Unlock()
if hadError {
p.failureCount++
nextInterval := p.calculateIntervalLocked()
slog.Warn("poll failed, backing off",
"failure_count", p.failureCount,
"next_interval", nextInterval)
} else {
if p.failureCount > 0 {
slog.Info("poll succeeded, resetting backoff")
p.failureCount = 0
}
}
}
// calculateIntervalLocked returns the next interval (must be called with lock held)
func (p *Poller) calculateIntervalLocked() time.Duration {
if p.failureCount == 0 {
return p.interval
}
// Calculate backoff: interval * (backoffFactor ^ failureCount)
backoff := float64(p.interval)
for i := 0; i < p.failureCount; i++ {
backoff *= p.backoffFactor
if time.Duration(backoff) > p.maxBackoff {
return p.maxBackoff
}
}
return time.Duration(backoff)
}
// correlateApproval attempts to match an approval with a pending tool call
func (p *Poller) correlateApproval(ctx context.Context, fc humanlayer.FunctionCall) {
// Find the session with this run_id
session, err := p.conversationStore.GetSessionByRunID(ctx, fc.RunID)
if err != nil {
slog.Error("failed to get session for correlation",
"run_id", fc.RunID,
"error", err)
return
}
if session == nil {
// This is expected for approvals that aren't from our Claude sessions
slog.Debug("no matching session for approval",
"run_id", fc.RunID,
"approval_id", fc.CallID)
return
}
toolName := fc.Spec.Fn
// Try to find an uncorrelated pending tool call for this session and tool
// Use the specialized method that only finds tool calls without approvals
var toolCall *store.ConversationEvent
// Check if the store has the GetUncorrelatedPendingToolCall method
if uncorrelatedStore, ok := p.conversationStore.(*store.SQLiteStore); ok {
toolCall, err = uncorrelatedStore.GetUncorrelatedPendingToolCall(ctx, session.ID, toolName)
} else {
// Fallback to regular GetPendingToolCall
toolCall, err = p.conversationStore.GetPendingToolCall(ctx, session.ID, toolName)
}
if err != nil || toolCall == nil {
// For continued sessions, also check the parent session's tool calls
if session.ParentSessionID != "" {
parentToolCall, err := p.conversationStore.GetPendingToolCall(ctx, session.ParentSessionID, toolName)
if err == nil && parentToolCall != nil {
// Found in parent session - correlate with parent
if err := p.conversationStore.CorrelateApproval(ctx, session.ParentSessionID, toolName, fc.CallID); err != nil {
slog.Error("failed to correlate approval with parent session tool call",
"approval_id", fc.CallID,
"parent_session_id", session.ParentSessionID,
"tool_name", toolName,
"error", err)
return
}
slog.Info("correlated approval with parent session tool call",
"approval_id", fc.CallID,
"parent_session_id", session.ParentSessionID,
"current_session_id", session.ID,
"tool_name", toolName,
"run_id", fc.RunID)
// Update current session status to waiting_input
waitingStatus := store.SessionStatusWaitingInput
update := store.SessionUpdate{
Status: &waitingStatus,
}
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to waiting_input",
"session_id", session.ID,
"error", err)
}
// Publish session status change event
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventSessionStatusChanged,
Data: map[string]interface{}{
"session_id": session.ID,
"run_id": fc.RunID,
"old_status": string(store.SessionStatusRunning),
"new_status": string(store.SessionStatusWaitingInput),
},
})
}
return
}
}
slog.Debug("no matching pending tool call for approval",
"session_id", session.ID,
"run_id", fc.RunID,
"tool_name", toolName,
"approval_id", fc.CallID)
return
}
// Found a matching tool call in one of our sessions - correlate it
if err := p.conversationStore.CorrelateApproval(ctx, session.ID, toolName, fc.CallID); err != nil {
slog.Error("failed to correlate approval with tool call",
"approval_id", fc.CallID,
"session_id", session.ID,
"tool_name", toolName,
"error", err)
return
}
slog.Info("correlated approval with tool call",
"approval_id", fc.CallID,
"session_id", session.ID,
"tool_name", toolName,
"run_id", fc.RunID)
// Publish event to notify that approval has been correlated
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{
"approval_id": fc.CallID,
"session_id": session.ID,
"tool_name": toolName,
"run_id": fc.RunID,
"status": "correlated",
},
})
}
// Update session status to waiting_input
waitingStatus := store.SessionStatusWaitingInput
update := store.SessionUpdate{
Status: &waitingStatus,
}
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to waiting_input",
"session_id", session.ID,
"error", err)
}
// Publish session status change event
if p.eventBus != nil {
p.eventBus.Publish(bus.Event{
Type: bus.EventSessionStatusChanged,
Data: map[string]interface{}{
"session_id": session.ID,
"run_id": fc.RunID,
"old_status": string(store.SessionStatusRunning),
"new_status": string(store.SessionStatusWaitingInput),
},
})
}
}

View File

@@ -1,176 +0,0 @@
package approval
import (
"context"
"errors"
"testing"
"time"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"go.uber.org/mock/gomock"
)
func TestPoller_Poll(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// Create mock client with test data
mockClient := NewMockAPIClient(ctrl)
mockStore := NewMockStore(ctrl)
testFunctionCalls := []humanlayer.FunctionCall{
{CallID: "fc-1", RunID: "run-1"},
{CallID: "fc-2", RunID: "run-2"},
}
testHumanContacts := []humanlayer.HumanContact{
{CallID: "hc-1", RunID: "run-1"},
}
// Set expectations
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return(testFunctionCalls, nil)
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return(testHumanContacts, nil)
// Expect reconciliation calls
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil)
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil)
// Expect checks for existing approvals
mockStore.EXPECT().GetFunctionCall("fc-1").Return(nil, errors.New("not found"))
mockStore.EXPECT().GetFunctionCall("fc-2").Return(nil, errors.New("not found"))
mockStore.EXPECT().GetHumanContact("hc-1").Return(nil, errors.New("not found"))
mockStore.EXPECT().StoreFunctionCall(testFunctionCalls[0]).Return(nil)
mockStore.EXPECT().StoreFunctionCall(testFunctionCalls[1]).Return(nil)
mockStore.EXPECT().StoreHumanContact(testHumanContacts[0]).Return(nil)
// Create poller with short interval for testing
poller := &Poller{
client: mockClient,
store: mockStore,
interval: 10 * time.Millisecond,
maxBackoff: 100 * time.Millisecond,
backoffFactor: 2.0,
}
// Poll once
ctx := context.Background()
poller.poll(ctx)
// Verify no failure count
if poller.failureCount != 0 {
t.Errorf("expected failure count 0, got %d", poller.failureCount)
}
}
func TestPoller_Backoff(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockClient := NewMockAPIClient(ctrl)
mockStore := NewMockStore(ctrl)
apiErr := errors.New("API error")
// First two calls fail
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return(nil, apiErr).Times(2)
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return(nil, apiErr).Times(2)
// Third call succeeds
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil)
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil)
// Expect reconciliation calls for successful poll
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil)
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil)
poller := &Poller{
client: mockClient,
store: mockStore,
interval: 10 * time.Millisecond,
maxBackoff: 100 * time.Millisecond,
backoffFactor: 2.0,
}
ctx := context.Background()
// First poll should fail and increment failure count
poller.poll(ctx)
if poller.failureCount != 1 {
t.Errorf("expected failure count 1, got %d", poller.failureCount)
}
// Calculate expected interval (10ms * 2^1 = 20ms)
interval := poller.calculateInterval()
if interval != 20*time.Millisecond {
t.Errorf("expected interval 20ms, got %v", interval)
}
// Second poll should fail and increment again
poller.poll(ctx)
if poller.failureCount != 2 {
t.Errorf("expected failure count 2, got %d", poller.failureCount)
}
// Calculate expected interval (10ms * 2^2 = 40ms)
interval = poller.calculateInterval()
if interval != 40*time.Millisecond {
t.Errorf("expected interval 40ms, got %v", interval)
}
// Test max backoff
poller.failureCount = 10
interval = poller.calculateInterval()
if interval != poller.maxBackoff {
t.Errorf("expected max backoff %v, got %v", poller.maxBackoff, interval)
}
// Test reset on success
poller.poll(ctx)
if poller.failureCount != 0 {
t.Errorf("expected failure count reset to 0, got %d", poller.failureCount)
}
}
func TestPoller_StartStop(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockClient := NewMockAPIClient(ctrl)
mockStore := NewMockStore(ctrl)
// Expect at least 2 polls during the test duration
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
// Expect reconciliation calls
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
poller := NewPoller(mockClient, mockStore, nil, 50*time.Millisecond, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start poller
err := poller.Start(ctx)
if err != nil {
t.Fatalf("failed to start poller: %v", err)
}
// Wait for a few polls
time.Sleep(120 * time.Millisecond)
// Try to start again while running should fail
err = poller.Start(ctx)
if err == nil {
t.Error("expected error starting poller twice")
}
// Stop poller
poller.Stop()
// Verify it's stopped
if poller.IsRunning() {
t.Error("expected poller to be stopped")
}
}

View File

@@ -2,62 +2,21 @@ package approval
import (
"context"
"encoding/json"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"github.com/humanlayer/humanlayer/hld/store"
)
// PendingApproval wraps either a function call or human contact
type PendingApproval struct {
Type string `json:"type"` // "function_call" or "human_contact"
FunctionCall *humanlayer.FunctionCall `json:"function_call,omitempty"`
HumanContact *humanlayer.HumanContact `json:"human_contact,omitempty"`
}
// Manager defines the interface for managing approvals
// Manager defines the interface for managing local approvals
type Manager interface {
// Lifecycle methods
Start(ctx context.Context) error
Stop()
// Create a new approval
CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error)
// Retrieval methods
GetPendingApprovals(sessionID string) ([]PendingApproval, error)
GetPendingApprovalsByRunID(runID string) ([]PendingApproval, error)
GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error)
GetApproval(ctx context.Context, id string) (*store.Approval, error)
// Decision methods
ApproveFunctionCall(ctx context.Context, callID string, comment string) error
DenyFunctionCall(ctx context.Context, callID string, reason string) error
RespondToHumanContact(ctx context.Context, callID string, response string) error
// Recovery methods
ReconcileApprovalsForSession(ctx context.Context, runID string) error
}
// Store manages approval storage and correlation
type Store interface {
// Storage methods
StoreFunctionCall(fc humanlayer.FunctionCall) error
StoreHumanContact(hc humanlayer.HumanContact) error
// Retrieval methods
GetFunctionCall(callID string) (*humanlayer.FunctionCall, error)
GetHumanContact(callID string) (*humanlayer.HumanContact, error)
GetAllPending() ([]PendingApproval, error)
GetPendingByRunID(runID string) ([]PendingApproval, error)
GetAllCachedFunctionCalls() ([]humanlayer.FunctionCall, error)
GetAllCachedHumanContacts() ([]humanlayer.HumanContact, error)
// Update methods
MarkFunctionCallResponded(callID string) error
MarkHumanContactResponded(callID string) error
RemoveFunctionCall(callID string) error
RemoveHumanContact(callID string) error
}
// APIClient defines the interface for interacting with the HumanLayer API
type APIClient interface {
GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error)
GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error)
ApproveFunctionCall(ctx context.Context, callID string, comment string) error
DenyFunctionCall(ctx context.Context, callID string, reason string) error
RespondToHumanContact(ctx context.Context, callID string, response string) error
ApproveToolCall(ctx context.Context, id string, comment string) error
DenyToolCall(ctx context.Context, id string, reason string) error
}

View File

@@ -8,8 +8,8 @@ import (
"sync/atomic"
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/rpc"
"github.com/humanlayer/humanlayer/hld/store"
)
// client provides a JSON-RPC 2.0 client for communicating with the HumanLayer daemon
@@ -266,7 +266,7 @@ func (c *client) ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueS
}
// FetchApprovals fetches pending approvals from the daemon
func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, error) {
func (c *client) FetchApprovals(sessionID string) ([]*store.Approval, error) {
req := rpc.FetchApprovalsRequest{
SessionID: sessionID,
}
@@ -277,11 +277,11 @@ func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, e
return resp.Approvals, nil
}
// SendDecision sends a decision (approve/deny/respond) for an approval
func (c *client) SendDecision(callID, approvalType, decision, comment string) error {
// SendDecision sends a decision (approve/deny) for an approval
func (c *client) SendDecision(approvalID, decision, comment string) error {
req := rpc.SendDecisionRequest{
CallID: callID,
Type: approvalType,
CallID: approvalID, // Using CallID for backward compatibility
Type: "function_call", // Always function_call for local approvals
Decision: decision,
Comment: comment,
}
@@ -295,25 +295,17 @@ func (c *client) SendDecision(callID, approvalType, decision, comment string) er
return nil
}
// ApproveFunctionCall approves a function call with an optional comment
func (c *client) ApproveFunctionCall(callID, comment string) error {
return c.SendDecision(callID, string(rpc.ApprovalTypeFunctionCall), string(rpc.DecisionApprove), comment)
// ApproveToolCall approves a tool call with an optional comment
func (c *client) ApproveToolCall(approvalID, comment string) error {
return c.SendDecision(approvalID, "approve", comment)
}
// DenyFunctionCall denies a function call with a required reason
func (c *client) DenyFunctionCall(callID, reason string) error {
// DenyToolCall denies a tool call with a required reason
func (c *client) DenyToolCall(approvalID, reason string) error {
if reason == "" {
return fmt.Errorf("reason is required when denying a function call")
return fmt.Errorf("reason is required when denying a tool call")
}
return c.SendDecision(callID, string(rpc.ApprovalTypeFunctionCall), string(rpc.DecisionDeny), reason)
}
// RespondToHumanContact responds to a human contact request
func (c *client) RespondToHumanContact(callID, response string) error {
if response == "" {
return fmt.Errorf("response is required for human contact")
}
return c.SendDecision(callID, string(rpc.ApprovalTypeHumanContact), string(rpc.DecisionRespond), response)
return c.SendDecision(approvalID, "deny", reason)
}
// GetConversation fetches the conversation history for a session

View File

@@ -10,10 +10,9 @@ import (
"testing"
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
"github.com/humanlayer/humanlayer/hld/rpc"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"github.com/humanlayer/humanlayer/hld/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -143,30 +142,25 @@ func TestClient_FetchApprovals(t *testing.T) {
server, socketPath := newMockRPCServer(t)
defer server.stop()
// Create test approvals that match what the TUI expects
testApprovals := []approval.PendingApproval{
// Create test approvals with new local format
testApprovals := []*store.Approval{
{
Type: "function_call",
FunctionCall: &humanlayer.FunctionCall{
CallID: "fc-123",
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
Kwargs: map[string]interface{}{
"arg": "value",
},
},
RunID: "run-123",
},
ID: "local-123",
RunID: "run-123",
SessionID: "session-123",
Status: store.ApprovalStatusLocalPending,
CreatedAt: time.Now(),
ToolName: "test_function",
ToolInput: json.RawMessage(`{"arg": "value"}`),
},
{
Type: "human_contact",
HumanContact: &humanlayer.HumanContact{
CallID: "hc-456",
Spec: humanlayer.HumanContactSpec{
Msg: "Need help with something",
},
RunID: "run-456",
},
ID: "local-456",
RunID: "run-456",
SessionID: "session-456",
Status: store.ApprovalStatusLocalPending,
CreatedAt: time.Now(),
ToolName: "another_function",
ToolInput: json.RawMessage(`{"msg": "test message"}`),
},
}
@@ -187,15 +181,15 @@ func TestClient_FetchApprovals(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, approvals, 2)
// Verify function call
assert.Equal(t, "function_call", approvals[0].Type)
assert.NotNil(t, approvals[0].FunctionCall)
assert.Equal(t, "fc-123", approvals[0].FunctionCall.CallID)
// Verify first approval
assert.Equal(t, "local-123", approvals[0].ID)
assert.Equal(t, "test_function", approvals[0].ToolName)
assert.Equal(t, store.ApprovalStatusLocalPending, approvals[0].Status)
// Verify human contact
assert.Equal(t, "human_contact", approvals[1].Type)
assert.NotNil(t, approvals[1].HumanContact)
assert.Equal(t, "hc-456", approvals[1].HumanContact.CallID)
// Verify second approval
assert.Equal(t, "local-456", approvals[1].ID)
assert.Equal(t, "another_function", approvals[1].ToolName)
assert.Equal(t, store.ApprovalStatusLocalPending, approvals[1].Status)
}
func TestClient_SendDecision(t *testing.T) {
@@ -227,15 +221,18 @@ func TestClient_SendDecision(t *testing.T) {
defer func() { _ = c.Close() }()
// Test approve
err = c.SendDecision("test-123", "function_call", "approve", "looks good")
err = c.SendDecision("test-123", "approve", "looks good")
assert.NoError(t, err)
// Test deny
err = c.SendDecision("test-456", "function_call", "deny", "too risky")
err = c.SendDecision("test-456", "deny", "too risky")
assert.NoError(t, err)
// Test respond
err = c.SendDecision("test-789", "human_contact", "respond", "here is my response")
// Test the convenience methods
err = c.ApproveToolCall("test-789", "approved")
assert.NoError(t, err)
err = c.DenyToolCall("test-890", "not allowed")
assert.NoError(t, err)
}

View File

@@ -3,8 +3,8 @@ package client
import (
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/rpc"
"github.com/humanlayer/humanlayer/hld/store"
)
// Client defines the interface for communicating with the HumanLayer daemon
@@ -25,17 +25,14 @@ type Client interface {
ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error)
// FetchApprovals fetches pending approvals from the daemon
FetchApprovals(sessionID string) ([]approval.PendingApproval, error)
FetchApprovals(sessionID string) ([]*store.Approval, error)
// SendDecision sends a decision (approve/deny/respond) for an approval
SendDecision(callID, approvalType, decision, comment string) error
// SendDecision sends a decision (approve/deny) for an approval
SendDecision(approvalID, decision, comment string) error
// Type-safe approval methods for function calls
ApproveFunctionCall(callID, comment string) error
DenyFunctionCall(callID, reason string) error
// Type-safe approval methods for human contacts
RespondToHumanContact(callID, response string) error
// Type-safe approval methods
ApproveToolCall(approvalID, comment string) error
DenyToolCall(approvalID, reason string) error
// GetConversation fetches the conversation history for a session
GetConversation(sessionID string) (*rpc.GetConversationResponse, error)

View File

@@ -86,28 +86,10 @@ func New() (*Daemon, error) {
return nil, fmt.Errorf("failed to create session manager: %w", err)
}
// Create approval manager if API key is configured
var approvalManager approval.Manager
if cfg.APIKey != "" {
slog.Info("creating approval manager", "api_base_url", cfg.APIBaseURL)
approvalCfg := approval.Config{
APIKey: cfg.APIKey,
BaseURL: cfg.APIBaseURL,
// Use defaults for now, could add to daemon config later
}
approvalManager, err = approval.NewManager(approvalCfg, eventBus, conversationStore)
if err != nil {
return nil, fmt.Errorf("failed to create approval manager: %w", err)
}
slog.Debug("approval manager created successfully")
} else {
slog.Warn("no API key configured, approval features disabled")
}
// Set approval reconciler on session manager if approval manager exists
if approvalManager != nil {
sessionManager.SetApprovalReconciler(approvalManager)
}
// Always create local approval manager
slog.Info("creating local approval manager")
approvalManager := approval.NewManager(conversationStore, eventBus)
slog.Debug("local approval manager created successfully")
return &Daemon{
config: cfg,
@@ -172,23 +154,10 @@ func (d *Daemon) Run(ctx context.Context) error {
sessionHandlers := rpc.NewSessionHandlers(d.sessions, d.store)
sessionHandlers.Register(d.rpcServer)
// Always register approval handlers (even without API key)
// Register local approval handlers
approvalHandlers := rpc.NewApprovalHandlers(d.approvals, d.sessions)
approvalHandlers.Register(d.rpcServer)
// Start approval polling if approval manager is available
if d.approvals != nil {
if err := d.approvals.Start(ctx); err != nil {
_ = listener.Close()
return fmt.Errorf("failed to start approval poller: %w", err)
}
defer d.approvals.Stop()
slog.Info("approval polling started")
} else {
slog.Warn("approval manager not configured (no API key)")
}
slog.Info("daemon started", "socket", d.socketPath)
// Accept connections until context is cancelled

View File

@@ -7,118 +7,22 @@ import (
"encoding/json"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/config"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
"github.com/humanlayer/humanlayer/hld/rpc"
"github.com/humanlayer/humanlayer/hld/session"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// mockAPIClient is a test implementation of the HumanLayer API client
// This simulates a backend API for integration testing
type mockAPIClient struct {
functionCalls []humanlayer.FunctionCall
humanContacts []humanlayer.HumanContact
decisions map[string]string // call_id -> decision
mu sync.Mutex
}
func newMockAPIClient() *mockAPIClient {
return &mockAPIClient{
decisions: make(map[string]string),
}
}
func (m *mockAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Return only function calls that haven't been decided
var pending []humanlayer.FunctionCall
for _, fc := range m.functionCalls {
if _, decided := m.decisions[fc.CallID]; !decided {
pending = append(pending, fc)
}
}
return pending, nil
}
func (m *mockAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Return only contacts that haven't been decided
var pending []humanlayer.HumanContact
for _, hc := range m.humanContacts {
if _, decided := m.decisions[hc.CallID]; !decided {
pending = append(pending, hc)
}
}
return pending, nil
}
func (m *mockAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.decisions[callID] = "approved"
return nil
}
func (m *mockAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.decisions[callID] = "denied"
return nil
}
func (m *mockAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.decisions[callID] = "responded"
return nil
}
func TestDaemonApprovalIntegration(t *testing.T) {
// Create test socket path
socketPath := testutil.SocketPath(t, "daemon-approval-test")
// Create mock API client with test data
mockClient := newMockAPIClient()
mockClient.functionCalls = []humanlayer.FunctionCall{
{
CallID: "fc-1",
RunID: "test-run-1",
Spec: humanlayer.FunctionCallSpec{
Fn: "dangerous_function",
Kwargs: map[string]interface{}{
"action": "delete_all",
},
},
},
{
CallID: "fc-2",
RunID: "test-run-2",
Spec: humanlayer.FunctionCallSpec{
Fn: "safe_function",
},
},
}
mockClient.humanContacts = []humanlayer.HumanContact{
{
CallID: "hc-1",
RunID: "test-run-1",
Spec: humanlayer.HumanContactSpec{
Msg: "Need human help",
},
},
}
// Create a minimal in-memory store for testing
testStore, err := store.NewSQLiteStore(":memory:")
if err != nil {
@@ -126,16 +30,16 @@ func TestDaemonApprovalIntegration(t *testing.T) {
}
defer testStore.Close()
// Create real approval components for integration testing
approvalStore := approval.NewMemoryStore()
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, nil)
// Create event bus for the approval manager
eventBus := bus.NewEventBus()
// We need to manually construct the manager with our test client
approvalManager := &approval.DefaultManager{
Client: mockClient,
Store: approvalStore,
Poller: poller,
ConversationStore: testStore,
// Create local approval manager
approvalManager := approval.NewManager(testStore, eventBus)
// Create session manager
sessionManager, err := session.NewManager(eventBus, testStore)
if err != nil {
t.Fatalf("failed to create session manager: %v", err)
}
// Create test daemon with approval manager
@@ -146,16 +50,11 @@ func TestDaemonApprovalIntegration(t *testing.T) {
},
socketPath: socketPath,
approvals: approvalManager,
sessions: sessionManager,
eventBus: eventBus,
store: testStore,
}
// Create session manager (we don't need real sessions for this test)
sessionManager, err := session.NewManager(nil, testStore)
if err != nil {
t.Fatalf("failed to create session manager: %v", err)
}
d.sessions = sessionManager
// Start daemon
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -178,82 +77,158 @@ func TestDaemonApprovalIntegration(t *testing.T) {
// Create RPC client
client := &rpcClient{conn: conn}
// Wait for poller to fetch initial data
time.Sleep(100 * time.Millisecond)
// Test scenario: Create local approvals and test RPC operations
// Test 1: Fetch all approvals
// First, create test sessions and approvals in the database
ctx2 := context.Background()
// Create test sessions
session1 := &store.Session{
ID: "test-session-1",
RunID: "test-run-1",
Query: "Test query 1",
Status: store.SessionStatusRunning,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
}
if err := testStore.CreateSession(ctx2, session1); err != nil {
t.Fatalf("failed to create session 1: %v", err)
}
session2 := &store.Session{
ID: "test-session-2",
RunID: "test-run-2",
Query: "Test query 2",
Status: store.SessionStatusRunning,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
}
if err := testStore.CreateSession(ctx2, session2); err != nil {
t.Fatalf("failed to create session 2: %v", err)
}
// Create approvals via the RPC interface
var createResp rpc.CreateApprovalResponse
// Approval 1
err = client.call("createApproval", rpc.CreateApprovalRequest{
RunID: "test-run-1",
ToolName: "dangerous_function",
ToolInput: json.RawMessage(`{"action": "delete_all"}`),
}, &createResp)
if err != nil {
t.Fatalf("failed to create approval 1: %v", err)
}
approval1ID := createResp.ApprovalID
// Approval 2
err = client.call("createApproval", rpc.CreateApprovalRequest{
RunID: "test-run-2",
ToolName: "safe_function",
ToolInput: json.RawMessage(`{"action": "read_only"}`),
}, &createResp)
if err != nil {
t.Fatalf("failed to create approval 2: %v", err)
}
approval2ID := createResp.ApprovalID
// Test 1: Fetch all approvals (should be empty without session filter)
var fetchResp rpc.FetchApprovalsResponse
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{}, &fetchResp)
if err != nil {
t.Fatalf("failed to fetch approvals: %v", err)
}
if len(fetchResp.Approvals) != 3 {
t.Errorf("expected 3 approvals, got %d", len(fetchResp.Approvals))
if len(fetchResp.Approvals) != 0 {
t.Errorf("expected 0 approvals without session filter, got %d", len(fetchResp.Approvals))
}
// Test 2: Approve a function call
// Test 2: Fetch approvals for session 1
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{
SessionID: "test-session-1",
}, &fetchResp)
if err != nil {
t.Fatalf("failed to fetch approvals for session 1: %v", err)
}
if len(fetchResp.Approvals) != 1 {
t.Errorf("expected 1 approval for session 1, got %d", len(fetchResp.Approvals))
}
// Test 3: Approve a function call
var decisionResp rpc.SendDecisionResponse
err = client.call("sendDecision", rpc.SendDecisionRequest{
CallID: "fc-1",
CallID: approval1ID,
Type: "function_call",
Decision: "approve",
Comment: "Looks good",
}, &decisionResp)
if err != nil {
t.Fatalf("failed to send decision: %v", err)
t.Fatalf("failed to send approval decision: %v", err)
}
if !decisionResp.Success {
t.Errorf("expected success, got error: %s", decisionResp.Error)
}
// Verify the decision was recorded
if mockClient.decisions["fc-1"] != "approved" {
t.Errorf("expected fc-1 to be approved, got %s", mockClient.decisions["fc-1"])
}
// Test 3: Deny a function call
// Test 4: Deny a function call
err = client.call("sendDecision", rpc.SendDecisionRequest{
CallID: "fc-2",
CallID: approval2ID,
Type: "function_call",
Decision: "deny",
Comment: "Too risky",
}, &decisionResp)
if err != nil {
t.Fatalf("failed to send decision: %v", err)
t.Fatalf("failed to send deny decision: %v", err)
}
if !decisionResp.Success {
t.Errorf("expected success, got error: %s", decisionResp.Error)
}
// Test 4: Respond to human contact
err = client.call("sendDecision", rpc.SendDecisionRequest{
CallID: "hc-1",
Type: "human_contact",
Decision: "respond",
Comment: "Here's the help you need",
}, &decisionResp)
if err != nil {
t.Fatalf("failed to send decision: %v", err)
}
if !decisionResp.Success {
t.Errorf("expected success, got error: %s", decisionResp.Error)
}
// Wait for next poll cycle
time.Sleep(100 * time.Millisecond)
// Test 5: Verify approvals are no longer pending
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{}, &fetchResp)
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{
SessionID: "test-session-1",
}, &fetchResp)
if err != nil {
t.Fatalf("failed to fetch approvals: %v", err)
}
if len(fetchResp.Approvals) != 0 {
t.Errorf("expected 0 pending approvals after decisions, got %d", len(fetchResp.Approvals))
t.Errorf("expected 0 pending approvals for session 1 after approval, got %d", len(fetchResp.Approvals))
}
// Test 6: Try to approve non-existent approval
err = client.call("sendDecision", rpc.SendDecisionRequest{
CallID: "non-existent",
Type: "function_call",
Decision: "approve",
Comment: "Should fail",
}, &decisionResp)
if err != nil {
t.Fatalf("failed to send decision: %v", err)
}
if decisionResp.Success {
t.Error("expected failure for non-existent approval")
}
// Test 7: Human contact is no longer supported
err = client.call("sendDecision", rpc.SendDecisionRequest{
CallID: "some-id",
Type: "human_contact",
Decision: "respond",
Comment: "Should fail",
}, &decisionResp)
if err != nil {
t.Fatalf("failed to send decision: %v", err)
}
if decisionResp.Success {
t.Error("expected failure for human contact type")
}
if decisionResp.Error != "human contact approvals are no longer supported" {
t.Errorf("expected specific error for human contact, got: %s", decisionResp.Error)
}
// Shutdown daemon

View File

@@ -4,7 +4,8 @@ package daemon
import (
"context"
"fmt"
"encoding/json"
"net"
"path/filepath"
"testing"
"time"
@@ -14,57 +15,30 @@ import (
"github.com/humanlayer/humanlayer/hld/client"
"github.com/humanlayer/humanlayer/hld/config"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
"github.com/humanlayer/humanlayer/hld/rpc"
"github.com/humanlayer/humanlayer/hld/session"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// TestSessionStateTransitionsIntegration tests the full flow of session state changes
// when approvals are created and resolved
func TestSessionStateTransitionsIntegration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
func TestDaemonSessionStateIntegration(t *testing.T) {
// This test verifies that session state transitions work correctly with local approvals
// Create temporary socket path for test
socketPath := testutil.SocketPath(t, "session-state")
// Create test socket
socketPath := testutil.SocketPath(t, "test")
// Use a temporary database file instead of :memory: to ensure all connections
// access the same database (in-memory databases are unique per connection)
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
// Create the store
// Create test database
dbPath := filepath.Join(t.TempDir(), "test.db")
testStore, err := store.NewSQLiteStore(dbPath)
if err != nil {
t.Fatalf("failed to create test store: %v", err)
t.Fatalf("failed to create store: %v", err)
}
defer testStore.Close()
// Set environment variables to ensure consistent test behavior
t.Setenv("HUMANLAYER_DATABASE_PATH", dbPath)
t.Setenv("HUMANLAYER_API_KEY", "test-key")
// Create event bus
eventBus := bus.NewEventBus()
// Create mock API client
mockClient := &mockSessionStateAPIClient{
functionCalls: make(map[string]*humanlayer.FunctionCall),
}
// Create real approval components
approvalStore := approval.NewMemoryStore()
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, eventBus)
// Create approval manager
approvalManager := &approval.DefaultManager{
Client: mockClient,
Store: approvalStore,
Poller: poller,
EventBus: eventBus,
ConversationStore: testStore,
}
// Create local approval manager
approvalManager := approval.NewManager(testStore, eventBus)
// Create session manager
sessionManager, err := session.NewManager(eventBus, testStore)
@@ -125,36 +99,43 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
t.Fatalf("failed to create session: %v", err)
}
// 2. Add a tool call that needs approval
toolCall := &store.ConversationEvent{
SessionID: sessionID,
// 2. Add a tool call event that would need approval
toolCallEvent := &store.ConversationEvent{
SessionID: sessionID, // Use the actual session ID from sessions table
ClaudeSessionID: claudeSessionID,
Sequence: 1,
EventType: store.EventTypeToolCall,
ToolID: "tool-001",
ToolName: "dangerous_function",
ToolID: "tool-call-123",
ToolInputJSON: `{"action": "delete_all"}`,
ApprovalStatus: "", // No approval yet
ApprovalID: "",
CreatedAt: time.Now(),
}
if err := testStore.AddConversationEvent(ctx, toolCall); err != nil {
if err := testStore.AddConversationEvent(ctx, toolCallEvent); err != nil {
t.Fatalf("failed to add tool call: %v", err)
}
// 3. Simulate an approval coming from HumanLayer API
approvalID := "approval-001"
mockClient.AddFunctionCall(humanlayer.FunctionCall{
CallID: approvalID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "dangerous_function",
Kwargs: map[string]interface{}{
"action": "delete_all",
},
},
})
// 3. Create an approval via RPC (simulating MCP creating approval)
// We need to use the rpcClient directly for this test
conn, err := net.Dial("unix", socketPath)
if err != nil {
t.Fatalf("failed to connect to daemon: %v", err)
}
defer conn.Close()
// 4. Wait for poller to pick up the approval and correlate it
time.Sleep(150 * time.Millisecond)
rpcClient := &rpcClient{conn: conn}
var createResp rpc.CreateApprovalResponse
if err := rpcClient.call("createApproval", rpc.CreateApprovalRequest{
RunID: runID,
ToolName: "dangerous_function",
ToolInput: json.RawMessage(`{"action": "delete_all"}`),
}, &createResp); err != nil {
t.Fatalf("failed to create approval: %v", err)
}
approvalID := createResp.ApprovalID
// 4. Give time for event bus to propagate and correlation to happen
time.Sleep(100 * time.Millisecond)
// 5. Check that session status changed to waiting_input
updatedSession, err := testStore.GetSession(ctx, sessionID)
@@ -165,7 +146,16 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
t.Errorf("expected session status to be waiting_input, got %s", updatedSession.Status)
}
// 6. Check that approval was correlated
// 6. Check that approval was created in database
approval, err := testStore.GetApproval(ctx, approvalID)
if err != nil {
t.Fatalf("failed to get approval: %v", err)
}
if approval.Status != store.ApprovalStatusLocalPending {
t.Errorf("expected approval status to be pending, got %s", approval.Status)
}
// 7. Check that approval was correlated with tool call
conversation, err := testStore.GetConversation(ctx, claudeSessionID)
if err != nil {
t.Fatalf("failed to get conversation: %v", err)
@@ -189,15 +179,15 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
t.Errorf("expected approval ID to be %s, got %s", approvalID, correlatedEvent.ApprovalID)
}
// 7. Approve the function call via client
if err := c.SendDecision(approvalID, "function_call", "approve", "Approved for testing"); err != nil {
// 8. Approve the function call via client
if err := c.SendDecision(approvalID, "approve", "Approved for testing"); err != nil {
t.Fatalf("failed to send approval: %v", err)
}
// Give time for status update
time.Sleep(100 * time.Millisecond)
// 8. Check that session status changed back to running
// 9. Check that session status changed back to running
finalSession, err := testStore.GetSession(ctx, sessionID)
if err != nil {
t.Fatalf("failed to get final session: %v", err)
@@ -206,7 +196,16 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
t.Errorf("expected session status to be running after approval, got %s", finalSession.Status)
}
// 9. Check that approval status was updated
// 10. Check that approval status was updated
finalApproval, err := testStore.GetApproval(ctx, approvalID)
if err != nil {
t.Fatalf("failed to get final approval: %v", err)
}
if finalApproval.Status != store.ApprovalStatusLocalApproved {
t.Errorf("expected approval status to be approved, got %s", finalApproval.Status)
}
// 11. Check that approval status was updated in conversation
finalConversation, err := testStore.GetConversation(ctx, claudeSessionID)
if err != nil {
t.Fatalf("failed to get final conversation: %v", err)
@@ -239,60 +238,3 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
t.Error("daemon did not shut down in time")
}
}
// Mock API client for session state testing
type mockSessionStateAPIClient struct {
functionCalls map[string]*humanlayer.FunctionCall
}
func (m *mockSessionStateAPIClient) AddFunctionCall(fc humanlayer.FunctionCall) {
m.functionCalls[fc.CallID] = &fc
}
func (m *mockSessionStateAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
var result []humanlayer.FunctionCall
for _, fc := range m.functionCalls {
if fc.Status == nil || fc.Status.RespondedAt == nil {
result = append(result, *fc)
}
}
return result, nil
}
func (m *mockSessionStateAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
return []humanlayer.HumanContact{}, nil
}
func (m *mockSessionStateAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
if fc, ok := m.functionCalls[callID]; ok {
if fc.Status == nil {
fc.Status = &humanlayer.FunctionCallStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
fc.Status.RespondedAt = &now
approved := true
fc.Status.Approved = &approved
fc.Status.Comment = comment
return nil
}
return fmt.Errorf("function call not found: %s", callID)
}
func (m *mockSessionStateAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
if fc, ok := m.functionCalls[callID]; ok {
if fc.Status == nil {
fc.Status = &humanlayer.FunctionCallStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
fc.Status.RespondedAt = &now
approved := false
fc.Status.Approved = &approved
fc.Status.Comment = reason
return nil
}
return fmt.Errorf("function call not found: %s", callID)
}
func (m *mockSessionStateAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
return fmt.Errorf("not implemented")
}

View File

@@ -7,15 +7,16 @@ import (
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/session"
"github.com/humanlayer/humanlayer/hld/store"
)
// ApprovalHandlers provides RPC handlers for approval management
// ApprovalHandlers provides RPC handlers for local approval management
type ApprovalHandlers struct {
approvals approval.Manager
sessions session.SessionManager
}
// NewApprovalHandlers creates new approval RPC handlers
// NewApprovalHandlers creates new local approval RPC handlers
func NewApprovalHandlers(approvals approval.Manager, sessions session.SessionManager) *ApprovalHandlers {
return &ApprovalHandlers{
approvals: approvals,
@@ -23,6 +24,47 @@ func NewApprovalHandlers(approvals approval.Manager, sessions session.SessionMan
}
}
// CreateApprovalRequest is the request for creating a local approval
type CreateApprovalRequest struct {
RunID string `json:"run_id"`
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
}
// CreateApprovalResponse is the response for creating a local approval
type CreateApprovalResponse struct {
ApprovalID string `json:"approval_id"`
}
// HandleCreateApproval handles the CreateApproval RPC method
func (h *ApprovalHandlers) HandleCreateApproval(ctx context.Context, params json.RawMessage) (interface{}, error) {
var req CreateApprovalRequest
if err := json.Unmarshal(params, &req); err != nil {
return nil, fmt.Errorf("invalid request: %w", err)
}
// Validate required fields
if req.RunID == "" {
return nil, fmt.Errorf("run_id is required")
}
if req.ToolName == "" {
return nil, fmt.Errorf("tool_name is required")
}
if req.ToolInput == nil {
return nil, fmt.Errorf("tool_input is required")
}
// Create the approval
approvalID, err := h.approvals.CreateApproval(ctx, req.RunID, req.ToolName, req.ToolInput)
if err != nil {
return nil, fmt.Errorf("failed to create approval: %w", err)
}
return &CreateApprovalResponse{
ApprovalID: approvalID,
}, nil
}
// FetchApprovalsRequest is the request for fetching approvals
type FetchApprovalsRequest struct {
SessionID string `json:"session_id,omitempty"` // Optional filter by session
@@ -30,16 +72,11 @@ type FetchApprovalsRequest struct {
// FetchApprovalsResponse is the response for fetching approvals
type FetchApprovalsResponse struct {
Approvals []approval.PendingApproval `json:"approvals"`
Approvals []*store.Approval `json:"approvals"`
}
// HandleFetchApprovals handles the FetchApprovals RPC method
func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json.RawMessage) (interface{}, error) {
// Check if approval manager is configured
if h.approvals == nil {
return nil, fmt.Errorf("approval features not available: no API key configured")
}
var req FetchApprovalsRequest
if params != nil {
if err := json.Unmarshal(params, &req); err != nil {
@@ -47,27 +84,17 @@ func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json
}
}
var approvals []approval.PendingApproval
var err error
// If no session ID provided, return empty list
if req.SessionID == "" {
return &FetchApprovalsResponse{
Approvals: []*store.Approval{},
}, nil
}
if req.SessionID != "" {
// Get the session to find its run_id
sessionInfo, err := h.sessions.GetSessionInfo(req.SessionID)
if err != nil {
return nil, fmt.Errorf("session not found: %w", err)
}
// Get approvals by run_id
approvals, err = h.approvals.GetPendingApprovalsByRunID(sessionInfo.RunID)
if err != nil {
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
}
} else {
// Get all pending approvals
approvals, err = h.approvals.GetPendingApprovals("")
if err != nil {
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
}
// Get approvals for the session
approvals, err := h.approvals.GetPendingApprovals(ctx, req.SessionID)
if err != nil {
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
}
return &FetchApprovalsResponse{
@@ -77,9 +104,9 @@ func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json
// SendDecisionRequest is the request for sending a decision
type SendDecisionRequest struct {
CallID string `json:"call_id"`
Type string `json:"type"` // "function_call" or "human_contact"
Decision string `json:"decision"` // "approve", "deny", or "respond"
CallID string `json:"call_id"` // Actually approval ID, but keeping name for compatibility
Type string `json:"type"` // Ignored for local approvals
Decision string `json:"decision"` // "approve" or "deny"
Comment string `json:"comment,omitempty"`
}
@@ -91,11 +118,6 @@ type SendDecisionResponse struct {
// HandleSendDecision handles the SendDecision RPC method
func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.RawMessage) (interface{}, error) {
// Check if approval manager is configured
if h.approvals == nil {
return nil, fmt.Errorf("approval features not available: no API key configured")
}
var req SendDecisionRequest
if err := json.Unmarshal(params, &req); err != nil {
return nil, fmt.Errorf("invalid request: %w", err)
@@ -105,38 +127,30 @@ func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.R
if req.CallID == "" {
return nil, fmt.Errorf("call_id is required")
}
if req.Type == "" {
return nil, fmt.Errorf("type is required")
}
if req.Decision == "" {
return nil, fmt.Errorf("decision is required")
}
// Check if this is a human contact type (no longer supported)
if req.Type == "human_contact" {
return &SendDecisionResponse{
Success: false,
Error: "human contact approvals are no longer supported",
}, nil
}
var err error
switch req.Type {
case "function_call":
switch req.Decision {
case "approve":
err = h.approvals.ApproveFunctionCall(ctx, req.CallID, req.Comment)
case "deny":
if req.Comment == "" {
return nil, fmt.Errorf("comment is required for denial")
}
err = h.approvals.DenyFunctionCall(ctx, req.CallID, req.Comment)
default:
return nil, fmt.Errorf("invalid decision for function_call: %s", req.Decision)
}
case "human_contact":
if req.Decision != "respond" {
return nil, fmt.Errorf("invalid decision for human_contact: %s", req.Decision)
}
switch req.Decision {
case "approve":
err = h.approvals.ApproveToolCall(ctx, req.CallID, req.Comment)
case "deny":
if req.Comment == "" {
return nil, fmt.Errorf("comment is required for human contact response")
return nil, fmt.Errorf("comment is required for denial")
}
err = h.approvals.RespondToHumanContact(ctx, req.CallID, req.Comment)
err = h.approvals.DenyToolCall(ctx, req.CallID, req.Comment)
default:
return nil, fmt.Errorf("invalid type: %s", req.Type)
return nil, fmt.Errorf("invalid decision: %s (must be 'approve' or 'deny')", req.Decision)
}
if err != nil {
@@ -151,8 +165,43 @@ func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.R
}, nil
}
// Register registers all approval handlers with the RPC server
// GetApprovalRequest is the request for getting a specific approval
type GetApprovalRequest struct {
ApprovalID string `json:"approval_id"`
}
// GetApprovalResponse is the response for getting a specific approval
type GetApprovalResponse struct {
Approval *store.Approval `json:"approval"`
}
// HandleGetApproval handles the GetApproval RPC method
func (h *ApprovalHandlers) HandleGetApproval(ctx context.Context, params json.RawMessage) (interface{}, error) {
var req GetApprovalRequest
if err := json.Unmarshal(params, &req); err != nil {
return nil, fmt.Errorf("invalid request: %w", err)
}
// Validate required fields
if req.ApprovalID == "" {
return nil, fmt.Errorf("approval_id is required")
}
// Get the approval
approval, err := h.approvals.GetApproval(ctx, req.ApprovalID)
if err != nil {
return nil, fmt.Errorf("failed to get approval: %w", err)
}
return &GetApprovalResponse{
Approval: approval,
}, nil
}
// Register registers all local approval handlers with the RPC server
func (h *ApprovalHandlers) Register(server *Server) {
server.Register("createApproval", h.HandleCreateApproval)
server.Register("fetchApprovals", h.HandleFetchApprovals)
server.Register("getApproval", h.HandleGetApproval)
server.Register("sendDecision", h.HandleSendDecision)
}

View File

@@ -173,6 +173,28 @@ func (s *SQLiteStore) initSchema() error {
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
description TEXT
);
-- Approvals table for local approvals
CREATE TABLE IF NOT EXISTS approvals (
id TEXT PRIMARY KEY,
run_id TEXT NOT NULL,
session_id TEXT NOT NULL,
status TEXT NOT NULL CHECK (status IN ('pending', 'approved', 'denied')),
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
responded_at DATETIME,
-- Tool approval fields
tool_name TEXT NOT NULL,
tool_input TEXT NOT NULL, -- JSON
-- Response fields
comment TEXT, -- For denial reasons or approval notes
FOREIGN KEY (session_id) REFERENCES sessions(id)
);
CREATE INDEX IF NOT EXISTS idx_approvals_pending ON approvals(status) WHERE status = 'pending';
CREATE INDEX IF NOT EXISTS idx_approvals_session ON approvals(session_id);
CREATE INDEX IF NOT EXISTS idx_approvals_run_id ON approvals(run_id);
`
if _, err := s.db.Exec(schema); err != nil {
@@ -258,6 +280,49 @@ func (s *SQLiteStore) applyMigrations() error {
slog.Info("Migration 3 applied successfully")
}
// Migration 4: Add approvals table for local approvals
if currentVersion < 4 {
slog.Info("Applying migration 4: Add approvals table for local approvals")
// Create the approvals table
_, err = s.db.Exec(`
CREATE TABLE IF NOT EXISTS approvals (
id TEXT PRIMARY KEY,
run_id TEXT NOT NULL,
session_id TEXT NOT NULL,
status TEXT NOT NULL CHECK (status IN ('pending', 'approved', 'denied')),
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
responded_at DATETIME,
-- Tool approval fields
tool_name TEXT NOT NULL,
tool_input TEXT NOT NULL, -- JSON
-- Response fields
comment TEXT, -- For denial reasons or approval notes
FOREIGN KEY (session_id) REFERENCES sessions(id)
);
CREATE INDEX IF NOT EXISTS idx_approvals_pending ON approvals(status) WHERE status = 'pending';
CREATE INDEX IF NOT EXISTS idx_approvals_session ON approvals(session_id);
CREATE INDEX IF NOT EXISTS idx_approvals_run_id ON approvals(run_id);
`)
if err != nil {
return fmt.Errorf("failed to create approvals table: %w", err)
}
// Record migration
_, err = s.db.Exec(`
INSERT INTO schema_version (version, description)
VALUES (4, 'Add approvals table for local approvals')
`)
if err != nil {
return fmt.Errorf("failed to record migration 4: %w", err)
}
slog.Info("Migration 4 applied successfully")
}
return nil
}
@@ -1104,6 +1169,154 @@ func (s *SQLiteStore) StoreRawEvent(ctx context.Context, sessionID string, event
return nil
}
// CreateApproval creates a new approval
func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) error {
// Validate status
if !approval.Status.IsValid() {
return fmt.Errorf("invalid approval status: %s", approval.Status)
}
query := `
INSERT INTO approvals (
id, run_id, session_id, status, created_at,
tool_name, tool_input
) VALUES (?, ?, ?, ?, ?, ?, ?)
`
_, err := s.db.ExecContext(ctx, query,
approval.ID, approval.RunID, approval.SessionID, approval.Status.String(), approval.CreatedAt,
approval.ToolName, string(approval.ToolInput),
)
if err != nil {
return fmt.Errorf("failed to create approval: %w", err)
}
return nil
}
// GetApproval retrieves an approval by ID
func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, error) {
query := `
SELECT id, run_id, session_id, status, created_at, responded_at,
tool_name, tool_input, comment
FROM approvals WHERE id = ?
`
var approval Approval
var respondedAt sql.NullTime
var comment sql.NullString
var statusStr string
var toolInputStr string
err := s.db.QueryRowContext(ctx, query, id).Scan(
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
&approval.CreatedAt, &respondedAt,
&approval.ToolName, &toolInputStr, &comment,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("approval not found: %s", id)
}
if err != nil {
return nil, fmt.Errorf("failed to get approval: %w", err)
}
// Convert status string to ApprovalStatus
approval.Status = ApprovalStatus(statusStr)
if !approval.Status.IsValid() {
return nil, fmt.Errorf("invalid approval status in database: %s", statusStr)
}
// Handle nullable fields
if respondedAt.Valid {
approval.RespondedAt = &respondedAt.Time
}
approval.Comment = comment.String
approval.ToolInput = json.RawMessage(toolInputStr)
return &approval, nil
}
// GetPendingApprovals retrieves all pending approvals for a session
func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error) {
query := `
SELECT id, run_id, session_id, status, created_at, responded_at,
tool_name, tool_input, comment
FROM approvals
WHERE session_id = ? AND status = ?
ORDER BY created_at ASC
`
rows, err := s.db.QueryContext(ctx, query, sessionID, ApprovalStatusLocalPending.String())
if err != nil {
return nil, fmt.Errorf("failed to get pending approvals: %w", err)
}
defer func() { _ = rows.Close() }()
var approvals []*Approval
for rows.Next() {
var approval Approval
var respondedAt sql.NullTime
var comment sql.NullString
var statusStr string
var toolInputStr string
err := rows.Scan(
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
&approval.CreatedAt, &respondedAt,
&approval.ToolName, &toolInputStr, &comment,
)
if err != nil {
return nil, fmt.Errorf("failed to scan approval: %w", err)
}
// Convert status string to ApprovalStatus
approval.Status = ApprovalStatus(statusStr)
if !approval.Status.IsValid() {
return nil, fmt.Errorf("invalid approval status in database: %s", statusStr)
}
// Handle nullable fields
if respondedAt.Valid {
approval.RespondedAt = &respondedAt.Time
}
approval.Comment = comment.String
approval.ToolInput = json.RawMessage(toolInputStr)
approvals = append(approvals, &approval)
}
return approvals, nil
}
// UpdateApprovalResponse updates the status and comment of an approval
func (s *SQLiteStore) UpdateApprovalResponse(ctx context.Context, id string, status ApprovalStatus, comment string) error {
// Validate status
if !status.IsValid() {
return fmt.Errorf("invalid approval status: %s", status)
}
query := `
UPDATE approvals
SET status = ?, comment = ?, responded_at = CURRENT_TIMESTAMP
WHERE id = ?
`
result, err := s.db.ExecContext(ctx, query, status.String(), comment, id)
if err != nil {
return fmt.Errorf("failed to update approval response: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("approval not found: %s", id)
}
return nil
}
// Helper function to convert MCP config to store format
func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServer) ([]MCPServer, error) {
// First, collect all server names and sort them for deterministic ordering

View File

@@ -24,6 +24,7 @@ type ConversationStore interface {
// Tool call operations
GetPendingToolCall(ctx context.Context, sessionID string, toolName string) (*ConversationEvent, error)
GetUncorrelatedPendingToolCall(ctx context.Context, sessionID string, toolName string) (*ConversationEvent, error)
GetPendingToolCalls(ctx context.Context, sessionID string) ([]*ConversationEvent, error)
MarkToolCallCompleted(ctx context.Context, toolID string, sessionID string) error
CorrelateApproval(ctx context.Context, sessionID string, toolName string, approvalID string) error
@@ -37,6 +38,12 @@ type ConversationStore interface {
// Raw event storage (for debugging)
StoreRawEvent(ctx context.Context, sessionID string, eventJSON string) error
// Approval operations for local approvals
CreateApproval(ctx context.Context, approval *Approval) error
GetApproval(ctx context.Context, id string) (*Approval, error)
GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error)
UpdateApprovalResponse(ctx context.Context, id string, status ApprovalStatus, comment string) error
// Database lifecycle
Close() error
}
@@ -123,6 +130,44 @@ type MCPServer struct {
EnvJSON string // JSON object
}
// ApprovalStatus represents the status of an approval
type ApprovalStatus string
// Valid approval statuses
const (
ApprovalStatusLocalPending ApprovalStatus = "pending"
ApprovalStatusLocalApproved ApprovalStatus = "approved"
ApprovalStatusLocalDenied ApprovalStatus = "denied"
)
// String returns the string representation of the status
func (s ApprovalStatus) String() string {
return string(s)
}
// IsValid checks if the status is valid
func (s ApprovalStatus) IsValid() bool {
switch s {
case ApprovalStatusLocalPending, ApprovalStatusLocalApproved, ApprovalStatusLocalDenied:
return true
default:
return false
}
}
// Approval represents a local approval request
type Approval struct {
ID string `json:"id"`
RunID string `json:"run_id"`
SessionID string `json:"session_id"`
Status ApprovalStatus `json:"status"`
CreatedAt time.Time `json:"created_at"`
RespondedAt *time.Time `json:"responded_at,omitempty"`
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
Comment string `json:"comment,omitempty"`
}
// EventType constants
const (
EventTypeMessage = "message"

22
hlyr/.npmignore Normal file
View File

@@ -0,0 +1,22 @@
# Test files and directories
hack/
test-local-approvals.ts
test-*.ts
test_local_approvals.md
# Test configs
mcp-config.json
test-mcp-logging.sh
# Source files (only dist is published)
src/
tsconfig.json
.eslintrc.json
.prettierrc
# Development files
*.log
.DS_Store
node_modules/
coverage/
.vscode/

View File

@@ -245,6 +245,10 @@ npm test
npm run dev
```
### Testing Local Approvals
For testing the local MCP approvals system without HumanLayer API access, see [test_local_approvals.md](./test_local_approvals.md).
## License
Apache-2.0

465
hlyr/hack/test-local-approvals.ts Executable file
View File

@@ -0,0 +1,465 @@
#!/usr/bin/env bun
import { connectWithRetry, DaemonClient, Approval } from '../src/daemonClient.js'
import { homedir } from 'os'
import { join } from 'path'
import * as fs from 'fs/promises'
import { spawn } from 'child_process'
import { Database } from 'bun:sqlite'
import { parseArgs } from 'util'
// Configuration
const SOCKET_PATH = join(homedir(), '.humanlayer', 'daemon.sock')
const DB_PATH = join(homedir(), '.humanlayer', 'daemon.db')
const MCP_LOG_DIR = join(homedir(), '.humanlayer', 'logs')
// ANSI color codes for output
const colors = {
reset: '\x1b[0m',
green: '\x1b[32m',
red: '\x1b[31m',
yellow: '\x1b[33m',
blue: '\x1b[34m',
cyan: '\x1b[36m',
magenta: '\x1b[35m',
}
function log(level: 'info' | 'success' | 'error' | 'debug' | 'mcp', message: string) {
const timestamp = new Date().toISOString()
const color = {
info: colors.blue,
success: colors.green,
error: colors.red,
debug: colors.cyan,
mcp: colors.magenta,
}[level]
console.log(`${color}[${timestamp}] [${level.toUpperCase()}] ${message}${colors.reset}`)
}
// Monitor MCP logs in real-time
async function monitorMCPLogs(runId: string): Promise<() => void> {
const logPath = join(MCP_LOG_DIR, `mcp-claude-approvals-${runId}.log`)
log('info', `Monitoring MCP logs at: ${logPath}`)
// Wait for log file to exist
let attempts = 0
while (attempts < 10) {
try {
await fs.access(logPath)
break
} catch {
await new Promise(resolve => setTimeout(resolve, 500))
attempts++
}
}
// Create a tail-like process to follow the log
const tail = spawn('tail', ['-f', logPath])
tail.stdout.on('data', data => {
const lines = data.toString().split('\n').filter(Boolean)
for (const line of lines) {
try {
const entry = JSON.parse(line)
const prefix = `[${entry.level}] ${entry.message}`
const dataStr = entry.data ? ` ${JSON.stringify(entry.data)}` : ''
log('mcp', `${prefix}${dataStr}`)
} catch {
// Not JSON, just print as-is
log('mcp', line)
}
}
})
tail.stderr.on('data', data => {
log('error', `MCP log error: ${data}`)
})
// Return cleanup function
return () => {
tail.kill()
}
}
// Launch a session that will definitely trigger an approval
async function launchApprovalSession(client: DaemonClient, testFile: string) {
log('info', 'Launching session that will trigger file write approval...')
const launchRequest = {
query: `Please write the text "Hello from MCP approval test!" to a file named "${testFile}". This is a test of the approval system.`,
working_dir: process.cwd(),
mcp_config: {
mcpServers: {
approvals: {
command: 'npm',
args: ['run', 'dev', 'mcp', 'claude_approvals'],
env: {
...process.env,
HUMANLAYER_MCP_DEBUG: 'true',
},
},
},
},
permission_prompt_tool: 'mcp__approvals__request_permission',
}
const session = await client.launchSession(launchRequest)
log('success', `Session launched: ${session.session_id}`)
log('info', `Run ID: ${session.run_id}`)
return session
}
// Launch a session for interactive monitoring
async function launchInteractiveSession(client: DaemonClient, query?: string) {
// Generate random content to ensure approval is always triggered
const timestamp = new Date().toISOString()
const randomId = Math.random().toString(36).substring(7)
const defaultQuery = `Please write "Hello from HumanLayer MCP test!\nTimestamp: ${timestamp}\nTest ID: ${randomId}" to a file named blah.txt`
const userQuery = query || defaultQuery
log('info', 'Launching interactive session...')
log('info', `Query: ${userQuery}`)
const launchRequest = {
query: userQuery,
working_dir: process.cwd(),
mcp_config: {
mcpServers: {
approvals: {
command: 'npm',
args: ['run', 'dev', 'mcp', 'claude_approvals'],
env: {
...process.env,
HUMANLAYER_MCP_DEBUG: 'true',
},
},
},
},
permission_prompt_tool: 'mcp__approvals__request_permission',
}
const session = await client.launchSession(launchRequest)
log('success', `Session launched: ${session.session_id}`)
log('info', `Run ID: ${session.run_id}`)
return session
}
// Automated test mode
async function runAutomatedTest() {
log('info', '=== Automated MCP Approval Test ===\n')
// Enable MCP debug logging
process.env.HUMANLAYER_MCP_DEBUG = 'true'
// Connect to daemon
const client = await connectWithRetry(SOCKET_PATH, 3, 1000)
log('success', 'Connected to daemon')
try {
// Generate test file name
const testFile = `test-mcp-approval-${Date.now()}.txt`
log('info', `Test file: ${testFile}`)
// Launch session
const session = await launchApprovalSession(client, testFile)
// Start monitoring MCP logs
const stopMonitoring = await monitorMCPLogs(session.run_id)
// Subscribe to approval events
const eventEmitter = await client.subscribe({
event_types: ['new_approval', 'approval_resolved'],
session_id: session.session_id,
})
let approvalReceived = false
let approvalId: string | null = null
eventEmitter.on('event', event => {
if (event.type === 'new_approval') {
log('success', `New approval event received!`)
log('debug', `Event data: ${JSON.stringify(event.data)}`)
approvalReceived = true
} else if (event.type === 'approval_resolved') {
log('success', `Approval resolved: ${JSON.stringify(event.data)}`)
}
})
// Wait for approval to be created
log('info', 'Waiting for Claude to request file write approval...')
const maxWait = 30000 // 30 seconds
const startTime = Date.now()
while (!approvalId && Date.now() - startTime < maxWait) {
// Check database for approvals using Bun's native SQLite
const db = new Database(DB_PATH, { readonly: true })
const stmt = db.prepare('SELECT * FROM approvals WHERE session_id = ? AND status = "pending"')
const approvals = stmt.all(session.session_id) as Approval[]
if (approvals.length > 0) {
const approval = approvals[0]
approvalId = approval.id
log('success', `Approval found in database: ${approvalId}`)
log('debug', `Tool: ${approval.tool_name}`)
log('debug', `Input: ${approval.tool_input}`)
// Auto-approve after a short delay
log('info', 'Auto-approving in 2 seconds...')
await new Promise(resolve => setTimeout(resolve, 2000))
try {
await client.sendDecision(approvalId, 'approve', 'Automated test approval')
log('success', '✓ Approval sent successfully')
// Wait for file to be created
await new Promise(resolve => setTimeout(resolve, 3000))
// Check if file was created
try {
await fs.access(testFile)
log('success', `✓ File "${testFile}" was created successfully`)
// Read and display content
const content = await fs.readFile(testFile, 'utf-8')
log('info', `File content: ${content}`)
// Clean up test file
await fs.unlink(testFile)
log('info', 'Test file cleaned up')
} catch {
log('error', 'File was not created - approval may have failed')
}
} catch (error) {
log('error', `Failed to approve: ${error}`)
}
}
db.close()
if (!approvalId) {
await new Promise(resolve => setTimeout(resolve, 1000))
}
}
if (!approvalId) {
log('error', 'No approval was requested within timeout period')
}
// Analyze MCP logs
const mcpLogPath = join(MCP_LOG_DIR, `mcp-claude-approvals-${session.run_id}.log`)
try {
const logs = await fs.readFile(mcpLogPath, 'utf-8')
const logLines = logs.split('\n').filter(Boolean)
log('info', `\nMCP Log Summary:`)
log('info', `Total entries: ${logLines.length}`)
const errors = logLines.filter(l => l.includes('"level":"ERROR"'))
if (errors.length > 0) {
log('error', `Found ${errors.length} errors in MCP logs`)
} else {
log('success', '✓ No errors in MCP logs')
}
} catch (error) {
log('error', `Could not analyze MCP logs: ${error}`)
}
// Summary
log('info', '\n=== Test Summary ===')
log('success', '✓ Session launched with MCP approvals')
log('success', '✓ MCP logs monitored successfully')
if (approvalReceived) {
log('success', '✓ Approval event received via subscription')
}
if (approvalId) {
log('success', '✓ Approval created and processed')
}
stopMonitoring()
} finally {
client.close()
// Exit cleanly after test completes
process.exit(0)
}
}
// Interactive monitoring mode
async function runInteractiveMode(query?: string) {
log('info', '=== Interactive MCP Monitoring Mode ===\n')
// Enable MCP debug logging
process.env.HUMANLAYER_MCP_DEBUG = 'true'
// Connect to daemon
const client = await connectWithRetry(SOCKET_PATH, 3, 1000)
log('success', 'Connected to daemon')
try {
// Launch session
const session = await launchInteractiveSession(client, query)
// Start monitoring MCP logs
const stopMonitoring = await monitorMCPLogs(session.run_id)
// Subscribe to approval events
const eventEmitter = await client.subscribe({
event_types: ['new_approval', 'approval_resolved', 'session_status_changed'],
session_id: session.session_id,
})
let sessionCompleted = false
eventEmitter.on('event', event => {
if (event.type === 'new_approval') {
log('success', '\n🔔 NEW APPROVAL REQUEST!')
log('info', `Approval ID: ${event.data.approval_id}`)
log('info', `Tool: ${event.data.tool_name || 'N/A'}`)
log('info', '\nYou can approve/deny this in:')
log('info', ' - TUI: npx humanlayer tui')
log('info', ' - WUI: Open the desktop app')
log('info', ` - Session URL: #/sessions/${session.session_id}\n`)
} else if (event.type === 'approval_resolved') {
const approved = event.data.approved ? 'approved' : 'denied'
log('success', `✓ Approval ${approved}: ${event.data.response_text || 'No comment'}`)
} else if (event.type === 'session_status_changed') {
const status = event.data.new_status || event.data.status || 'unknown'
log('info', `Session status: ${status}`)
// Exit when session completes or fails
if (status === 'completed' || status === 'failed') {
sessionCompleted = true
}
}
})
log('info', '\n📋 Session Information:')
log('info', `Session ID: ${session.session_id}`)
log('info', `Run ID: ${session.run_id}`)
log('info', '\n🛠 Manage approvals:')
log('info', ' TUI: npx humanlayer tui')
log('info', ' WUI: Open the HumanLayer desktop app')
log('info', '\n📊 Monitoring MCP logs...')
log('info', 'Press Ctrl+C to stop monitoring\n')
// Keep running until session completes or interrupted
await new Promise(resolve => {
const sigintHandler = () => {
log('info', '\nStopping monitor...')
process.removeListener('SIGINT', sigintHandler)
resolve(undefined)
}
process.on('SIGINT', sigintHandler)
// Check periodically if session has completed
const checkInterval = setInterval(() => {
if (sessionCompleted) {
clearInterval(checkInterval)
process.removeListener('SIGINT', sigintHandler)
log('info', '\n✨ Session completed, exiting...')
resolve(undefined)
}
}, 100)
})
stopMonitoring()
} finally {
client.close()
process.exit(0)
}
}
// Main entry point
async function main() {
// Parse command line arguments
const { values } = parseArgs({
args: process.argv.slice(2),
options: {
test: {
type: 'boolean',
short: 't',
default: false,
},
interactive: {
type: 'boolean',
short: 'i',
default: false,
},
query: {
type: 'string',
short: 'q',
},
},
})
// Check if hlyr is built (we're in the hlyr/hack directory)
const hlyrDistPath = join(__dirname, '..', 'dist', 'index.js')
try {
await fs.access(hlyrDistPath)
} catch {
log('error', 'hlyr is not built. Please run: cd .. && npm install && npm run build')
process.exit(1)
}
// Ensure MCP log directory exists
await fs.mkdir(MCP_LOG_DIR, { recursive: true })
try {
if (values.test) {
await runAutomatedTest()
} else if (values.interactive || !values.test) {
await runInteractiveMode(values.query)
}
} catch (error) {
log('error', `Error: ${error}`)
process.exit(1)
}
}
// Show usage if needed
if (process.argv.includes('--help') || process.argv.includes('-h')) {
console.log(`
Local MCP Approvals Test Tool
Usage:
bun test-local-approvals.ts [options]
Options:
-t, --test Run automated test (launches session, triggers approval, auto-approves)
-i, --interactive Run in interactive mode (launches session, monitors logs, manual approval)
-q, --query Custom query for the session (interactive mode only)
-h, --help Show this help message
Examples:
# Run automated test
bun test-local-approvals.ts --test
# Interactive mode (default - will request to write to blah.txt)
bun test-local-approvals.ts
# Interactive mode with custom query
bun test-local-approvals.ts -q "Help me analyze this codebase"
# Interactive mode without triggering approval
bun test-local-approvals.ts -q "Hello, how are you?"
Notes:
- Run from the hlyr/hack directory
- Make sure the daemon is running: cd ../../hld && ./dist/bin/hld -debug
- Build hlyr first: cd .. && npm install && npm run build
- In interactive mode, use TUI or WUI to approve/deny
- MCP logs are saved to: ~/.humanlayer/logs/
`)
process.exit(0)
}
// Run main
main().catch(error => {
console.error('Fatal error:', error)
process.exit(1)
})

View File

@@ -40,12 +40,30 @@ interface Event {
type: 'new_approval' | 'approval_resolved' | 'session_status_changed'
timestamp: string
data: {
type?: 'function_call' | 'human_contact'
count?: number
// Common fields
session_id?: string
run_id?: string
// new_approval event fields
approval_id?: string
tool_name?: string
// approval_resolved event fields
approved?: boolean
response_text?: string
// session_status_changed event fields
old_status?: string
new_status?: string
parent_session_id?: string
// Legacy fields (may not be used)
type?: 'function_call' | 'human_contact'
count?: number
function_name?: string
message?: string
// Allow other fields
[key: string]: string | number | boolean | undefined
}
}
@@ -54,6 +72,18 @@ interface EventNotification {
event: Event
}
export interface Approval {
id: string
run_id: string
session_id: string
status: 'pending' | 'approved' | 'denied'
created_at: string
responded_at?: string
tool_name: string
tool_input: unknown
comment?: string
}
interface LaunchSessionRequest {
query: string
model?: string
@@ -316,15 +346,32 @@ export class DaemonClient extends EventEmitter {
return this.call<{ sessions: unknown[] }>('listSessions')
}
async fetchApprovals(sessionId: string): Promise<unknown[]> {
const resp = await this.call<{ approvals: unknown[] }>('fetchApprovals', { session_id: sessionId })
async createApproval(
runId: string,
toolName: string,
toolInput: unknown,
): Promise<{ approval_id: string }> {
return this.call<{ approval_id: string }>('createApproval', {
run_id: runId,
tool_name: toolName,
tool_input: toolInput,
})
}
async fetchApprovals(sessionId: string): Promise<Approval[]> {
const resp = await this.call<{ approvals: Approval[] }>('fetchApprovals', { session_id: sessionId })
return resp.approvals
}
async sendDecision(callId: string, type: string, decision: string, comment: string): Promise<void> {
async getApproval(approvalId: string): Promise<Approval> {
const resp = await this.call<{ approval: Approval }>('getApproval', { approval_id: approvalId })
return resp.approval
}
async sendDecision(approvalId: string, decision: string, comment: string): Promise<void> {
const resp = await this.call<{ success: boolean; error?: string }>('sendDecision', {
call_id: callId,
type,
call_id: approvalId, // Using call_id for backward compatibility
type: 'function_call', // Always function_call for local approvals
decision,
comment,
})

View File

@@ -8,6 +8,8 @@ import {
} from '@modelcontextprotocol/sdk/types.js'
import { humanlayer } from '@humanlayer/sdk'
import { resolveFullConfig } from './config.js'
import { DaemonClient } from './daemonClient.js'
import { logger } from './mcpLogger.js'
function validateAuth(): void {
const config = resolveFullConfig({})
@@ -95,13 +97,16 @@ export async function startDefaultMCPServer() {
/**
* Start the Claude approvals MCP server that provides request_permission functionality
* Returns responses in the format required by Claude Code SDK
*
* This now uses local approvals through the daemon instead of HumanLayer API
*/
export async function startClaudeApprovalsMCPServer() {
validateAuth()
// No auth validation needed - uses local daemon
logger.info('Starting Claude approvals MCP server')
const server = new Server(
{
name: 'humanlayer-claude-approvals',
name: 'humanlayer-claude-local-approvals',
version: '1.0.0',
},
{
@@ -111,16 +116,8 @@ export async function startClaudeApprovalsMCPServer() {
},
)
const resolvedConfig = resolveFullConfig({})
const hl = humanlayer({
apiKey: resolvedConfig.api_key,
...(resolvedConfig.api_base_url && { apiBaseUrl: resolvedConfig.api_base_url }),
...(resolvedConfig.run_id && { runId: resolvedConfig.run_id }),
...(Object.keys(resolvedConfig.contact_channel).length > 0 && {
contactChannel: resolvedConfig.contact_channel,
}),
})
// Create daemon client
const daemonClient = new DaemonClient()
server.setRequestHandler(ListToolsRequestSchema, async () => {
return {
@@ -142,56 +139,114 @@ export async function startClaudeApprovalsMCPServer() {
})
server.setRequestHandler(CallToolRequestSchema, async request => {
/**
* example input
* {
* "tool_name": "Write",
* "input": {
* "file_name": "hello.txt"
* "content": "Hello, how are you?"
* }
* }
*/
logger.debug('Received tool call request', { name: request.params.name })
if (request.params.name === 'request_permission') {
const toolName: string | undefined = request.params.arguments?.tool_name
if (!toolName) {
logger.error('Invalid tool name in request_permission', request.params.arguments)
throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name requesting permissions')
}
const input: Record<string, unknown> = request.params.arguments?.input || {}
const approvalResult = await hl.fetchHumanApproval({
spec: {
fn: toolName,
kwargs: input,
},
})
// Get run ID from environment (set by Claude Code)
const runId = process.env.HUMANLAYER_RUN_ID
if (!runId) {
logger.error('HUMANLAYER_RUN_ID not set in environment')
throw new McpError(ErrorCode.InternalError, 'HUMANLAYER_RUN_ID not set')
}
if (!approvalResult.approved) {
logger.info('Processing approval request', { runId, toolName })
try {
// Connect to daemon
logger.debug('Connecting to daemon...')
await daemonClient.connect()
logger.debug('Connected to daemon')
// Create approval request
logger.debug('Creating approval request...', { runId, toolName })
const createResponse = await daemonClient.createApproval(runId, toolName, input)
const approvalId = createResponse.approval_id
logger.info('Created approval', { approvalId })
// Poll for approval status
let approved = false
let comment = ''
let polling = true
while (polling) {
try {
// Get the specific approval by ID
logger.debug('Fetching approval status...', { approvalId })
const approval = (await daemonClient.getApproval(approvalId)) as {
id: string
status: string
comment?: string
}
logger.debug('Approval status', { status: approval.status })
if (approval.status !== 'pending') {
// Approval has been resolved
approved = approval.status === 'approved'
comment = approval.comment || ''
polling = false
logger.info('Approval resolved', {
approvalId,
status: approval.status,
approved,
})
} else {
// Still pending, wait and poll again
logger.debug('Approval still pending, polling again...')
await new Promise(resolve => setTimeout(resolve, 1000))
}
} catch (error) {
logger.error('Failed to get approval status', { error, approvalId })
// Re-throw the error since this is a critical failure
throw new McpError(ErrorCode.InternalError, 'Failed to get approval status')
}
}
if (!approved) {
logger.info('Approval denied', { approvalId, comment })
return {
content: [
{
type: 'text',
text: JSON.stringify({
behavior: 'deny',
message: comment || 'Request denied by human reviewer',
}),
},
],
}
}
logger.info('Approval granted', { approvalId })
return {
content: [
{
type: 'text',
text: JSON.stringify({
behavior: 'deny',
message: approvalResult.comment || 'Request denied by human reviewer',
behavior: 'allow',
updatedInput: input,
}),
},
],
}
}
return {
content: [
{
type: 'text',
text: JSON.stringify({
behavior: 'allow',
updatedInput: input,
}),
},
],
} catch (error) {
logger.error('Failed to process approval', error)
throw new McpError(
ErrorCode.InternalError,
`Failed to process approval: ${error instanceof Error ? error.message : String(error)}`,
)
} finally {
logger.debug('Closing daemon connection')
daemonClient.close()
}
}
@@ -199,5 +254,12 @@ export async function startClaudeApprovalsMCPServer() {
})
const transport = new StdioServerTransport()
await server.connect(transport)
try {
await server.connect(transport)
logger.info('MCP server connected and ready')
} catch (error) {
logger.error('Failed to start MCP server', error)
throw error
}
}

141
hlyr/src/mcpLogger.ts Normal file
View File

@@ -0,0 +1,141 @@
import * as fs from 'fs'
import * as path from 'path'
import * as os from 'os'
interface LogEntry {
timestamp: string
level: 'DEBUG' | 'INFO' | 'WARN' | 'ERROR'
component: string
message: string
data?: unknown
}
class MCPLogger {
private logDir: string
private logPath: string
private isDebug: boolean
private writeStream?: fs.WriteStream
constructor() {
// Check if debug mode is enabled
this.isDebug = process.env.HUMANLAYER_MCP_DEBUG === 'true'
// Create log directory
this.logDir = path.join(os.homedir(), '.humanlayer', 'logs')
if (!fs.existsSync(this.logDir)) {
fs.mkdirSync(this.logDir, { recursive: true })
}
// Create log file with run ID if available, otherwise use date
const runId = process.env.HUMANLAYER_RUN_ID
const identifier = runId || new Date().toISOString().split('T')[0]
this.logPath = path.join(this.logDir, `mcp-claude-approvals-${identifier}.log`)
// Open write stream in append mode
this.writeStream = fs.createWriteStream(this.logPath, { flags: 'a' })
// Log startup
this.info('MCP Logger initialized', {
logPath: this.logPath,
debug: this.isDebug,
runId: process.env.HUMANLAYER_RUN_ID,
})
}
private write(entry: LogEntry): void {
if (!this.writeStream) return
const line = JSON.stringify(entry) + '\n'
this.writeStream.write(line)
}
private shouldLog(level: 'DEBUG' | 'INFO' | 'WARN' | 'ERROR'): boolean {
// Always log warnings and errors
if (level === 'WARN' || level === 'ERROR') return true
// Only log debug/info if debug mode is enabled
return this.isDebug
}
debug(message: string, data?: unknown): void {
if (!this.shouldLog('DEBUG')) return
this.write({
timestamp: new Date().toISOString(),
level: 'DEBUG',
component: 'mcp-claude-approvals',
message,
data,
})
}
info(message: string, data?: unknown): void {
if (!this.shouldLog('INFO')) return
this.write({
timestamp: new Date().toISOString(),
level: 'INFO',
component: 'mcp-claude-approvals',
message,
data,
})
}
warn(message: string, data?: unknown): void {
if (!this.shouldLog('WARN')) return
this.write({
timestamp: new Date().toISOString(),
level: 'WARN',
component: 'mcp-claude-approvals',
message,
data,
})
}
error(message: string, error?: unknown): void {
if (!this.shouldLog('ERROR')) return
let errorData: unknown = error
if (error instanceof Error) {
errorData = {
name: error.name,
message: error.message,
stack: error.stack,
}
}
this.write({
timestamp: new Date().toISOString(),
level: 'ERROR',
component: 'mcp-claude-approvals',
message,
data: errorData,
})
}
close(): void {
if (this.writeStream) {
this.writeStream.end()
this.writeStream = undefined
}
}
}
// Export singleton instance
export const logger = new MCPLogger()
// Handle process exit
process.on('exit', () => {
logger.close()
})
process.on('SIGINT', () => {
logger.close()
process.exit(0)
})
process.on('SIGTERM', () => {
logger.close()
process.exit(0)
})

View File

@@ -0,0 +1,125 @@
# Testing Local MCP Approvals
This guide explains how to test the local MCP approvals system without requiring HumanLayer API access.
## Overview
The `hack/test-local-approvals.ts` script provides a comprehensive testing tool for verifying that the MCP server, daemon, and approval flow are working correctly with local-only approvals.
## Prerequisites
1. Build hlyr and the daemon:
```bash
npm run build
```
2. Start the daemon with debug logging:
```bash
./dist/bin/hld -debug
```
3. Have Bun installed (for running TypeScript directly)
## Running Tests
### Automated Test Mode
Launches a Claude session, triggers a file write approval, and automatically approves it after 2 seconds:
```bash
bun hack/test-local-approvals.ts --test
```
This mode is useful for:
- CI/CD pipelines
- Quick verification that the system is working
- Debugging the approval flow
### Interactive Mode (Default)
Launches a Claude session with a query that will trigger an approval, then monitors for events:
```bash
# Default query (writes to blah.txt with random content)
bun hack/test-local-approvals.ts
# Custom query
bun hack/test-local-approvals.ts -q "Help me analyze this codebase"
# Query that won't trigger approvals
bun hack/test-local-approvals.ts -q "Hello, how are you?"
```
While running in interactive mode:
- Approval requests will be highlighted in the console
- Use TUI (`npx humanlayer tui`) or WUI to approve/deny
- Press Ctrl+C to stop monitoring
## What the Test Does
1. **Connects to the daemon** via Unix socket
2. **Launches a Claude session** with MCP approvals enabled
3. **Monitors MCP logs** in real-time at `~/.humanlayer/logs/`
4. **Subscribes to daemon events**:
- `new_approval` - When an approval is requested
- `approval_resolved` - When an approval is approved/denied
- `session_status_changed` - When session status changes
5. **In test mode**: Automatically approves after 2 seconds
6. **In interactive mode**: Waits for manual approval via TUI/WUI
## Understanding the Output
### Successful Automated Test
```
[INFO] === Automated MCP Approval Test ===
[SUCCESS] Connected to daemon
[SUCCESS] Session launched: <session-id>
[SUCCESS] New approval event received!
[SUCCESS] ✓ Approval sent successfully
[SUCCESS] ✓ File "test-mcp-approval-XXX.txt" was created successfully
[SUCCESS] ✓ No errors in MCP logs
```
### Interactive Mode Events
```
🔔 NEW APPROVAL REQUEST!
Approval ID: local-XXXX
Tool: Write
```
## Troubleshooting
### "Failed to connect to daemon"
- Ensure the daemon is running: `./dist/bin/hld -debug`
- Check the socket exists: `ls ~/.humanlayer/daemon.sock`
### "hlyr is not built"
- Run `npm run build` from the hlyr directory
### No approval triggered
- The default query includes random content to ensure uniqueness
- If using a custom query, make sure it requests an action (like writing a file)
### MCP errors in logs
- Check `~/.humanlayer/logs/mcp-claude-approvals-*.log` for details
- Ensure you're using the latest built version
## Command Reference
```bash
Options:
-t, --test Run automated test
-i, --interactive Run in interactive mode (default)
-q, --query Custom query for the session
-h, --help Show help message
```

View File

@@ -1,9 +1,9 @@
package main
import (
"encoding/json"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
@@ -131,14 +131,18 @@ func fetchRequests(daemonClient client.Client) tea.Cmd {
// Convert approvals to our Request type
for _, approval := range approvals {
if approval.Type == "function_call" && approval.FunctionCall != nil {
fc := approval.FunctionCall
// Build a message from the function name and kwargs
message := fmt.Sprintf("Call %s", fc.Spec.Fn)
if len(fc.Spec.Kwargs) > 0 {
// For local approvals, we always have tool calls
// Build a message from the tool name and input
message := fmt.Sprintf("Call %s", approval.ToolName)
// Parse tool input to extract parameters
var toolParams map[string]interface{}
if len(approval.ToolInput) > 0 {
// Try to parse as JSON to get parameters
if err := json.Unmarshal(approval.ToolInput, &toolParams); err == nil && len(toolParams) > 0 {
// Add first few parameters to message
params := []string{}
for k, v := range fc.Spec.Kwargs {
for k, v := range toolParams {
params = append(params, fmt.Sprintf("%s=%v", k, v))
if len(params) >= 2 {
break
@@ -146,62 +150,30 @@ func fetchRequests(daemonClient client.Client) tea.Cmd {
}
message += fmt.Sprintf(" with %s", strings.Join(params, ", "))
}
createdAt := time.Now() // Default to now if not available
if fc.Status != nil && fc.Status.RequestedAt != nil {
createdAt = fc.Status.RequestedAt.Time
}
req := Request{
ID: fc.CallID,
CallID: fc.CallID,
RunID: fc.RunID,
Type: ApprovalRequest,
Message: message,
Tool: fc.Spec.Fn,
Parameters: fc.Spec.Kwargs,
CreatedAt: createdAt,
}
// Enrich with session info if available
if sess, ok := sessionsByRunID[fc.RunID]; ok {
req.SessionID = sess.ID
req.SessionQuery = truncate(sess.Query, 50)
req.SessionModel = sess.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
allRequests = append(allRequests, req)
} else if approval.Type == "human_contact" && approval.HumanContact != nil {
hc := approval.HumanContact
createdAt := time.Now() // Default to now if not available
if hc.Status != nil && hc.Status.RequestedAt != nil {
createdAt = hc.Status.RequestedAt.Time
}
req := Request{
ID: hc.CallID,
CallID: hc.CallID,
RunID: hc.RunID,
Type: HumanContactRequest,
Message: hc.Spec.Msg,
CreatedAt: createdAt,
}
// Enrich with session info if available
if sess, ok := sessionsByRunID[hc.RunID]; ok {
req.SessionID = sess.ID
req.SessionQuery = truncate(sess.Query, 50)
req.SessionModel = sess.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
allRequests = append(allRequests, req)
}
req := Request{
ID: approval.ID,
CallID: approval.ID, // Use approval ID as call ID
RunID: approval.RunID,
Type: ApprovalRequest,
Message: message,
Tool: approval.ToolName,
Parameters: toolParams,
CreatedAt: approval.CreatedAt,
}
// Enrich with session info if available
if sess, ok := sessionsByRunID[approval.RunID]; ok {
req.SessionID = sess.ID
req.SessionQuery = truncate(sess.Query, 50)
req.SessionModel = sess.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
allRequests = append(allRequests, req)
}
return fetchRequestsMsg{requests: allRequests}
@@ -247,14 +219,17 @@ func fetchSessionApprovals(daemonClient client.Client, sessionID string) tea.Cmd
// Convert to Request type
var requests []Request
for _, approval := range approvals {
if approval.Type == "function_call" && approval.FunctionCall != nil {
fc := approval.FunctionCall
message := fmt.Sprintf("Call %s", fc.Spec.Fn)
// For local approvals, we always have tool calls
message := fmt.Sprintf("Call %s", approval.ToolName)
// Add parameters to message
if len(fc.Spec.Kwargs) > 0 {
// Parse tool input to extract parameters
var toolParams map[string]interface{}
if len(approval.ToolInput) > 0 {
// Try to parse as JSON to get parameters
if err := json.Unmarshal(approval.ToolInput, &toolParams); err == nil && len(toolParams) > 0 {
// Add first few parameters to message
params := []string{}
for k, v := range fc.Spec.Kwargs {
for k, v := range toolParams {
params = append(params, fmt.Sprintf("%s=%v", k, v))
if len(params) >= 2 {
break
@@ -262,62 +237,30 @@ func fetchSessionApprovals(daemonClient client.Client, sessionID string) tea.Cmd
}
message += fmt.Sprintf(" with %s", strings.Join(params, ", "))
}
createdAt := time.Now()
if fc.Status != nil && fc.Status.RequestedAt != nil {
createdAt = fc.Status.RequestedAt.Time
}
req := Request{
ID: fc.CallID,
CallID: fc.CallID,
RunID: fc.RunID,
Type: ApprovalRequest,
Message: message,
Tool: fc.Spec.Fn,
Parameters: fc.Spec.Kwargs,
CreatedAt: createdAt,
}
// Add session info if available
if sessionInfo != nil {
req.SessionID = sessionInfo.ID
req.SessionQuery = truncate(sessionInfo.Query, 50)
req.SessionModel = sessionInfo.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
requests = append(requests, req)
} else if approval.Type == "human_contact" && approval.HumanContact != nil {
hc := approval.HumanContact
createdAt := time.Now()
if hc.Status != nil && hc.Status.RequestedAt != nil {
createdAt = hc.Status.RequestedAt.Time
}
req := Request{
ID: hc.CallID,
CallID: hc.CallID,
RunID: hc.RunID,
Type: HumanContactRequest,
Message: hc.Spec.Msg,
CreatedAt: createdAt,
}
// Add session info if available
if sessionInfo != nil {
req.SessionID = sessionInfo.ID
req.SessionQuery = truncate(sessionInfo.Query, 50)
req.SessionModel = sessionInfo.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
requests = append(requests, req)
}
req := Request{
ID: approval.ID,
CallID: approval.ID, // Use approval ID as call ID
RunID: approval.RunID,
Type: ApprovalRequest,
Message: message,
Tool: approval.ToolName,
Parameters: toolParams,
CreatedAt: approval.CreatedAt,
}
// Add session info if available
if sessionInfo != nil {
req.SessionID = sessionInfo.ID
req.SessionQuery = truncate(sessionInfo.Query, 50)
req.SessionModel = sessionInfo.Model
if req.SessionModel == "" {
req.SessionModel = "default"
}
}
requests = append(requests, req)
}
return fetchSessionApprovalsMsg{approvals: requests}
@@ -386,9 +329,9 @@ func sendApproval(daemonClient client.Client, callID string, approved bool, comm
return func() tea.Msg {
var err error
if approved {
err = daemonClient.ApproveFunctionCall(callID, comment)
err = daemonClient.ApproveToolCall(callID, comment)
} else {
err = daemonClient.DenyFunctionCall(callID, comment)
err = daemonClient.DenyToolCall(callID, comment)
}
if err != nil {
@@ -401,7 +344,8 @@ func sendApproval(daemonClient client.Client, callID string, approved bool, comm
func sendHumanResponse(daemonClient client.Client, requestID string, response string) tea.Cmd {
return func() tea.Msg {
err := daemonClient.RespondToHumanContact(requestID, response)
// Human contact is no longer supported in local approvals
err := fmt.Errorf("human contact approvals are no longer supported")
if err != nil {
return humanResponseSentMsg{err: err}
}