mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
Implement local-only approvals, removing HumanLayer API dependencies (#261)
* Implement local-only approvals, removing HumanLayer API dependencies Major refactoring to simplify the approval system by making it entirely local: - Add approvals table to SQLite with CRUD operations - Replace complex HumanLayer API types with simple local types - Rewrite approval manager to only handle local SQLite operations - Remove API polling, correlation, and remote approval code - Update RPC handlers to support local approvals with backward compatibility - Update MCP integration to create and poll local approvals - Update TUI to work with new simplified approval format - Fix all tests including integration tests This removes ~1800 lines of API integration code while maintaining all functionality. Approvals are now faster, simpler, and have no external dependencies. * Add getApproval RPC endpoint for local approval polling - Add GetApproval method to Manager interface - Implement GetApproval in manager to retrieve specific approvals by ID - Add HandleGetApproval RPC handler to expose endpoint - Add migration 4 to create approvals table for existing databases This allows the MCP server to poll for specific approval status instead of fetching all pending approvals repeatedly. 
* formatting * Update MCP server to use getApproval endpoint - Add getApproval method to daemonClient with proper TypeScript types - Refactor MCP approval polling to use single approval lookup instead of fetching all - Improve Event interface documentation with actual field names from daemon - Add MCP-specific logger for better debugging and monitoring - Simplify polling logic for cleaner, more efficient approval status checks * Add test tooling and documentation for local approvals - Add comprehensive test script for automated and interactive approval testing - Create test_local_approvals.md documentation with usage examples - Add .npmignore to exclude test files from npm package - Update README with reference to test documentation - Test script features auto-exit on session completion and random content generation - Fix event field names to match actual daemon event data structure * formatting
This commit is contained in:
@@ -61,7 +61,7 @@ status:
|
||||
# Generate mocks
|
||||
mocks:
|
||||
mockgen -source=session/types.go -destination=session/mock_session.go -package=session SessionManager
|
||||
mockgen -source=approval/types.go -destination=approval/mock_approval.go -package=approval Manager,Store,APIClient
|
||||
mockgen -source=approval/types.go -destination=approval/mock_approval.go -package=approval Manager
|
||||
mockgen -source=client/types.go -destination=client/mock_client.go -package=client Client,Factory
|
||||
mockgen -source=bus/types.go -destination=bus/mock_bus.go -package=bus EventBus
|
||||
mockgen -source=store/store.go -destination=store/mock_store.go -package=store ConversationStore
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestPoller_CorrelateApproval(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
functionCall humanlayer.FunctionCall
|
||||
setupStore func(mockStore *store.MockConversationStore)
|
||||
expectCorrelate bool
|
||||
expectStatusUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "successful_correlation",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-123",
|
||||
RunID: "run-456",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "dangerous_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// Expect GetSessionByRunID to find the session
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-456").
|
||||
Return(&store.Session{
|
||||
ID: "sess-789",
|
||||
RunID: "run-456",
|
||||
}, nil)
|
||||
|
||||
// Return a pending tool call
|
||||
mockStore.EXPECT().
|
||||
GetPendingToolCall(gomock.Any(), "sess-789", "dangerous_function").
|
||||
Return(&store.ConversationEvent{
|
||||
ID: 1,
|
||||
SessionID: "sess-789",
|
||||
ToolName: "dangerous_function",
|
||||
}, nil)
|
||||
|
||||
// Expect correlation
|
||||
mockStore.EXPECT().
|
||||
CorrelateApproval(gomock.Any(), "sess-789", "dangerous_function", "fc-123").
|
||||
Return(nil)
|
||||
|
||||
// Expect status update to waiting_input
|
||||
mockStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), "sess-789", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, sessionID string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusWaitingInput {
|
||||
t.Errorf("expected status update to waiting_input, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
},
|
||||
expectCorrelate: true,
|
||||
expectStatusUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "no_matching_session",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-999",
|
||||
RunID: "unknown-run",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "some_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns nil (no matching session)
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "unknown-run").
|
||||
Return(nil, nil)
|
||||
|
||||
// No further calls should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "nil_tool_call_returned",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-nil",
|
||||
RunID: "run-nil",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "nil_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns a matching session
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-nil").
|
||||
Return(&store.Session{
|
||||
ID: "sess-nil",
|
||||
RunID: "run-nil",
|
||||
}, nil)
|
||||
|
||||
// Return nil tool call (no error but no result)
|
||||
mockStore.EXPECT().
|
||||
GetPendingToolCall(gomock.Any(), "sess-nil", "nil_function").
|
||||
Return(nil, nil)
|
||||
|
||||
// No correlation or status update should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "get_session_error",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-error",
|
||||
RunID: "run-error",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "error_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns an error
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-error").
|
||||
Return(nil, fmt.Errorf("database error"))
|
||||
|
||||
// No further calls should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock store
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
tt.setupStore(mockStore)
|
||||
|
||||
// Create poller with mock store
|
||||
poller := &Poller{
|
||||
conversationStore: mockStore,
|
||||
}
|
||||
|
||||
// Call correlateApproval
|
||||
ctx := context.Background()
|
||||
poller.correlateApproval(ctx, tt.functionCall)
|
||||
|
||||
// The expectations are verified by gomock
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,294 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// MemoryStore is an in-memory implementation of Store
|
||||
type MemoryStore struct {
|
||||
mu sync.RWMutex
|
||||
functionCalls map[string]*humanlayer.FunctionCall // indexed by call_id
|
||||
humanContacts map[string]*humanlayer.HumanContact // indexed by call_id
|
||||
byRunID map[string][]string // run_id -> []call_id
|
||||
}
|
||||
|
||||
// NewMemoryStore creates a new in-memory store
|
||||
func NewMemoryStore() *MemoryStore {
|
||||
return &MemoryStore{
|
||||
functionCalls: make(map[string]*humanlayer.FunctionCall),
|
||||
humanContacts: make(map[string]*humanlayer.HumanContact),
|
||||
byRunID: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// StoreFunctionCall stores a function call approval request
|
||||
func (s *MemoryStore) StoreFunctionCall(fc humanlayer.FunctionCall) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Store by call_id
|
||||
s.functionCalls[fc.CallID] = &fc
|
||||
|
||||
// Index by run_id if present
|
||||
if fc.RunID != "" {
|
||||
s.addToRunIndex(fc.RunID, fc.CallID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoreHumanContact stores a human contact request
|
||||
func (s *MemoryStore) StoreHumanContact(hc humanlayer.HumanContact) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Store by call_id
|
||||
s.humanContacts[hc.CallID] = &hc
|
||||
|
||||
// Index by run_id if present
|
||||
if hc.RunID != "" {
|
||||
s.addToRunIndex(hc.RunID, hc.CallID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFunctionCall retrieves a function call by ID
|
||||
func (s *MemoryStore) GetFunctionCall(callID string) (*humanlayer.FunctionCall, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
fc, ok := s.functionCalls[callID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
// GetHumanContact retrieves a human contact by ID
|
||||
func (s *MemoryStore) GetHumanContact(callID string) (*humanlayer.HumanContact, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
hc, ok := s.humanContacts[callID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("human contact not found: %s", callID)
|
||||
}
|
||||
return hc, nil
|
||||
}
|
||||
|
||||
// GetAllPending returns all pending approvals
|
||||
func (s *MemoryStore) GetAllPending() ([]PendingApproval, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var approvals []PendingApproval
|
||||
|
||||
// Add function calls
|
||||
for _, fc := range s.functionCalls {
|
||||
// Only include if not yet responded
|
||||
if fc.Status == nil || fc.Status.RespondedAt == nil {
|
||||
approvals = append(approvals, PendingApproval{
|
||||
Type: "function_call",
|
||||
FunctionCall: fc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add human contacts
|
||||
for _, hc := range s.humanContacts {
|
||||
// Only include if not yet responded
|
||||
if hc.Status == nil || hc.Status.RespondedAt == nil {
|
||||
approvals = append(approvals, PendingApproval{
|
||||
Type: "human_contact",
|
||||
HumanContact: hc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return approvals, nil
|
||||
}
|
||||
|
||||
// GetPendingByRunID returns pending approvals for a specific run_id
|
||||
func (s *MemoryStore) GetPendingByRunID(runID string) ([]PendingApproval, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
callIDs, ok := s.byRunID[runID]
|
||||
if !ok {
|
||||
return []PendingApproval{}, nil
|
||||
}
|
||||
|
||||
var approvals []PendingApproval
|
||||
for _, callID := range callIDs {
|
||||
// Check function calls
|
||||
if fc, ok := s.functionCalls[callID]; ok {
|
||||
if fc.Status == nil || fc.Status.RespondedAt == nil {
|
||||
approvals = append(approvals, PendingApproval{
|
||||
Type: "function_call",
|
||||
FunctionCall: fc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check human contacts
|
||||
if hc, ok := s.humanContacts[callID]; ok {
|
||||
if hc.Status == nil || hc.Status.RespondedAt == nil {
|
||||
approvals = append(approvals, PendingApproval{
|
||||
Type: "human_contact",
|
||||
HumanContact: hc,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return approvals, nil
|
||||
}
|
||||
|
||||
// MarkFunctionCallResponded marks a function call as responded
|
||||
func (s *MemoryStore) MarkFunctionCallResponded(callID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
fc, ok := s.functionCalls[callID]
|
||||
if !ok {
|
||||
return fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
|
||||
// Update status to mark as responded
|
||||
if fc.Status == nil {
|
||||
fc.Status = &humanlayer.FunctionCallStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
fc.Status.RespondedAt = &now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkHumanContactResponded marks a human contact as responded
|
||||
func (s *MemoryStore) MarkHumanContactResponded(callID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
hc, ok := s.humanContacts[callID]
|
||||
if !ok {
|
||||
return fmt.Errorf("human contact not found: %s", callID)
|
||||
}
|
||||
|
||||
// Update status to mark as responded
|
||||
if hc.Status == nil {
|
||||
hc.Status = &humanlayer.HumanContactStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
hc.Status.RespondedAt = &now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCachedFunctionCalls returns all cached function calls
|
||||
func (s *MemoryStore) GetAllCachedFunctionCalls() ([]humanlayer.FunctionCall, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
calls := make([]humanlayer.FunctionCall, 0, len(s.functionCalls))
|
||||
for _, fc := range s.functionCalls {
|
||||
calls = append(calls, *fc)
|
||||
}
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// GetAllCachedHumanContacts returns all cached human contacts
|
||||
func (s *MemoryStore) GetAllCachedHumanContacts() ([]humanlayer.HumanContact, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
contacts := make([]humanlayer.HumanContact, 0, len(s.humanContacts))
|
||||
for _, hc := range s.humanContacts {
|
||||
contacts = append(contacts, *hc)
|
||||
}
|
||||
return contacts, nil
|
||||
}
|
||||
|
||||
// RemoveFunctionCall removes a function call from the store
|
||||
func (s *MemoryStore) RemoveFunctionCall(callID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
fc, ok := s.functionCalls[callID]
|
||||
if !ok {
|
||||
return nil // Already removed
|
||||
}
|
||||
|
||||
// Remove from main map
|
||||
delete(s.functionCalls, callID)
|
||||
|
||||
// Remove from run_id index
|
||||
if fc.RunID != "" {
|
||||
s.removeFromRunIndex(fc.RunID, callID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveHumanContact removes a human contact from the store
|
||||
func (s *MemoryStore) RemoveHumanContact(callID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
hc, ok := s.humanContacts[callID]
|
||||
if !ok {
|
||||
return nil // Already removed
|
||||
}
|
||||
|
||||
// Remove from main map
|
||||
delete(s.humanContacts, callID)
|
||||
|
||||
// Remove from run_id index
|
||||
if hc.RunID != "" {
|
||||
s.removeFromRunIndex(hc.RunID, callID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addToRunIndex adds a call_id to the run_id index
|
||||
func (s *MemoryStore) addToRunIndex(runID, callID string) {
|
||||
if _, exists := s.byRunID[runID]; !exists {
|
||||
s.byRunID[runID] = []string{}
|
||||
}
|
||||
|
||||
// Check if callID already exists to avoid duplicates
|
||||
for _, id := range s.byRunID[runID] {
|
||||
if id == callID {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.byRunID[runID] = append(s.byRunID[runID], callID)
|
||||
}
|
||||
|
||||
// removeFromRunIndex removes a call_id from the run_id index
|
||||
func (s *MemoryStore) removeFromRunIndex(runID, callID string) {
|
||||
ids, exists := s.byRunID[runID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Filter out the callID
|
||||
newIDs := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if id != callID {
|
||||
newIDs = append(newIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(newIDs) == 0 {
|
||||
delete(s.byRunID, runID)
|
||||
} else {
|
||||
s.byRunID[runID] = newIDs
|
||||
}
|
||||
}
|
||||
@@ -1,311 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
func TestMemoryStore_StoreFunctionCall(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
fc := humanlayer.FunctionCall{
|
||||
CallID: "test-call-1",
|
||||
RunID: "test-run-1",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
|
||||
// Store function call
|
||||
err := store.StoreFunctionCall(fc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to store function call: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve it
|
||||
retrieved, err := store.GetFunctionCall("test-call-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get function call: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.CallID != fc.CallID {
|
||||
t.Errorf("expected call_id %s, got %s", fc.CallID, retrieved.CallID)
|
||||
}
|
||||
if retrieved.RunID != fc.RunID {
|
||||
t.Errorf("expected run_id %s, got %s", fc.RunID, retrieved.RunID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryStore_StoreHumanContact(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
hc := humanlayer.HumanContact{
|
||||
CallID: "test-contact-1",
|
||||
RunID: "test-run-1",
|
||||
Spec: humanlayer.HumanContactSpec{
|
||||
Msg: "Test message",
|
||||
},
|
||||
}
|
||||
|
||||
// Store human contact
|
||||
err := store.StoreHumanContact(hc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to store human contact: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve it
|
||||
retrieved, err := store.GetHumanContact("test-contact-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get human contact: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.CallID != hc.CallID {
|
||||
t.Errorf("expected call_id %s, got %s", hc.CallID, retrieved.CallID)
|
||||
}
|
||||
if retrieved.RunID != hc.RunID {
|
||||
t.Errorf("expected run_id %s, got %s", hc.RunID, retrieved.RunID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryStore_GetAllPending(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
// Store some approvals
|
||||
fc1 := humanlayer.FunctionCall{
|
||||
CallID: "fc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
|
||||
}
|
||||
fc2 := humanlayer.FunctionCall{
|
||||
CallID: "fc-2",
|
||||
RunID: "run-2",
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func2"},
|
||||
Status: &humanlayer.FunctionCallStatus{
|
||||
RespondedAt: &humanlayer.CustomTime{Time: time.Now()},
|
||||
},
|
||||
}
|
||||
hc1 := humanlayer.HumanContact{
|
||||
CallID: "hc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
|
||||
}
|
||||
|
||||
if err := store.StoreFunctionCall(fc1); err != nil {
|
||||
t.Fatalf("failed to store function call fc1: %v", err)
|
||||
}
|
||||
if err := store.StoreFunctionCall(fc2); err != nil {
|
||||
t.Fatalf("failed to store function call fc2: %v", err)
|
||||
}
|
||||
if err := store.StoreHumanContact(hc1); err != nil {
|
||||
t.Fatalf("failed to store human contact hc1: %v", err)
|
||||
}
|
||||
|
||||
// Get all pending
|
||||
pending, err := store.GetAllPending()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending: %v", err)
|
||||
}
|
||||
|
||||
// Should have 2 pending (fc1 and hc1, not fc2)
|
||||
if len(pending) != 2 {
|
||||
t.Errorf("expected 2 pending approvals, got %d", len(pending))
|
||||
}
|
||||
|
||||
// Verify the right ones are included
|
||||
foundFC1 := false
|
||||
foundHC1 := false
|
||||
for _, p := range pending {
|
||||
if p.Type == "function_call" && p.FunctionCall.CallID == "fc-1" {
|
||||
foundFC1 = true
|
||||
}
|
||||
if p.Type == "human_contact" && p.HumanContact.CallID == "hc-1" {
|
||||
foundHC1 = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFC1 {
|
||||
t.Error("expected to find fc-1 in pending approvals")
|
||||
}
|
||||
if !foundHC1 {
|
||||
t.Error("expected to find hc-1 in pending approvals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryStore_GetPendingByRunID(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
// Store approvals for different runs
|
||||
fc1 := humanlayer.FunctionCall{
|
||||
CallID: "fc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
|
||||
}
|
||||
fc2 := humanlayer.FunctionCall{
|
||||
CallID: "fc-2",
|
||||
RunID: "run-2",
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func2"},
|
||||
}
|
||||
hc1 := humanlayer.HumanContact{
|
||||
CallID: "hc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
|
||||
}
|
||||
|
||||
if err := store.StoreFunctionCall(fc1); err != nil {
|
||||
t.Fatalf("failed to store function call fc1: %v", err)
|
||||
}
|
||||
if err := store.StoreFunctionCall(fc2); err != nil {
|
||||
t.Fatalf("failed to store function call fc2: %v", err)
|
||||
}
|
||||
if err := store.StoreHumanContact(hc1); err != nil {
|
||||
t.Fatalf("failed to store human contact hc1: %v", err)
|
||||
}
|
||||
|
||||
// Get pending for run-1
|
||||
pending, err := store.GetPendingByRunID("run-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending by run_id: %v", err)
|
||||
}
|
||||
|
||||
// Should have 2 approvals for run-1
|
||||
if len(pending) != 2 {
|
||||
t.Errorf("expected 2 pending approvals for run-1, got %d", len(pending))
|
||||
}
|
||||
|
||||
// Get pending for run-2
|
||||
pending, err = store.GetPendingByRunID("run-2")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending by run_id: %v", err)
|
||||
}
|
||||
|
||||
// Should have 1 approval for run-2
|
||||
if len(pending) != 1 {
|
||||
t.Errorf("expected 1 pending approval for run-2, got %d", len(pending))
|
||||
}
|
||||
|
||||
// Get pending for non-existent run
|
||||
pending, err = store.GetPendingByRunID("run-999")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending by run_id: %v", err)
|
||||
}
|
||||
|
||||
// Should have 0 approvals
|
||||
if len(pending) != 0 {
|
||||
t.Errorf("expected 0 pending approvals for non-existent run, got %d", len(pending))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryStore_MarkResponded(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
// Store approvals
|
||||
fc := humanlayer.FunctionCall{
|
||||
CallID: "fc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func1"},
|
||||
}
|
||||
hc := humanlayer.HumanContact{
|
||||
CallID: "hc-1",
|
||||
RunID: "run-1",
|
||||
Spec: humanlayer.HumanContactSpec{Msg: "msg1"},
|
||||
}
|
||||
|
||||
if err := store.StoreFunctionCall(fc); err != nil {
|
||||
t.Fatalf("failed to store function call: %v", err)
|
||||
}
|
||||
if err := store.StoreHumanContact(hc); err != nil {
|
||||
t.Fatalf("failed to store human contact: %v", err)
|
||||
}
|
||||
|
||||
// Mark function call as responded
|
||||
err := store.MarkFunctionCallResponded("fc-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to mark function call responded: %v", err)
|
||||
}
|
||||
|
||||
// Mark human contact as responded
|
||||
err = store.MarkHumanContactResponded("hc-1")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to mark human contact responded: %v", err)
|
||||
}
|
||||
|
||||
// Verify they're no longer pending
|
||||
pending, err := store.GetAllPending()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending: %v", err)
|
||||
}
|
||||
|
||||
if len(pending) != 0 {
|
||||
t.Errorf("expected 0 pending approvals after marking responded, got %d", len(pending))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewMemoryStore()
|
||||
done := make(chan bool)
|
||||
|
||||
// Concurrent writes
|
||||
go func() {
|
||||
var errors []error
|
||||
for i := 0; i < 100; i++ {
|
||||
fc := humanlayer.FunctionCall{
|
||||
CallID: fmt.Sprintf("fc-%d", i),
|
||||
RunID: fmt.Sprintf("run-%d", i%10),
|
||||
Spec: humanlayer.FunctionCallSpec{Fn: "func"},
|
||||
}
|
||||
if err := store.StoreFunctionCall(fc); err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("StoreFunctionCall errors during concurrent access: %v", errors)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var errors []error
|
||||
for i := 0; i < 100; i++ {
|
||||
hc := humanlayer.HumanContact{
|
||||
CallID: fmt.Sprintf("hc-%d", i),
|
||||
RunID: fmt.Sprintf("run-%d", i%10),
|
||||
Spec: humanlayer.HumanContactSpec{Msg: "msg"},
|
||||
}
|
||||
if err := store.StoreHumanContact(hc); err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("StoreHumanContact errors during concurrent access: %v", errors)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_, _ = store.GetAllPending()
|
||||
_, _ = store.GetPendingByRunID(fmt.Sprintf("run-%d", i%10))
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify data integrity
|
||||
pending, err := store.GetAllPending()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending: %v", err)
|
||||
}
|
||||
|
||||
// Should have 200 approvals (100 function calls + 100 human contacts)
|
||||
if len(pending) != 200 {
|
||||
t.Errorf("expected 200 pending approvals, got %d", len(pending))
|
||||
}
|
||||
}
|
||||
@@ -2,364 +2,233 @@ package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// Config holds configuration for the approval manager
|
||||
type Config struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
PollInterval time.Duration
|
||||
MaxBackoff time.Duration
|
||||
BackoffFactor float64
|
||||
// manager manages approvals locally without HumanLayer API
|
||||
type manager struct {
|
||||
store store.ConversationStore
|
||||
eventBus bus.EventBus
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of Manager
|
||||
type DefaultManager struct {
|
||||
Client APIClient
|
||||
Store Store
|
||||
Poller *Poller
|
||||
EventBus bus.EventBus
|
||||
ConversationStore store.ConversationStore
|
||||
// NewManager creates a new local approval manager
|
||||
func NewManager(store store.ConversationStore, eventBus bus.EventBus) Manager {
|
||||
return &manager{
|
||||
store: store,
|
||||
eventBus: eventBus,
|
||||
}
|
||||
}
|
||||
|
||||
// NewManager creates a new approval manager
|
||||
func NewManager(cfg Config, eventBus bus.EventBus, conversationStore store.ConversationStore) (Manager, error) {
|
||||
// Set defaults
|
||||
if cfg.PollInterval <= 0 {
|
||||
cfg.PollInterval = 5 * time.Second
|
||||
}
|
||||
if cfg.MaxBackoff <= 0 {
|
||||
cfg.MaxBackoff = 5 * time.Minute
|
||||
}
|
||||
if cfg.BackoffFactor <= 0 {
|
||||
cfg.BackoffFactor = 2.0
|
||||
}
|
||||
|
||||
// Create HumanLayer client
|
||||
opts := []humanlayer.ClientOption{
|
||||
humanlayer.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
|
||||
// Add base URL if provided
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, humanlayer.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
|
||||
client, err := humanlayer.NewClient(opts...)
|
||||
// CreateApproval creates a new local approval
|
||||
func (m *manager) CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error) {
|
||||
// Look up session by run_id
|
||||
session, err := m.store.GetSessionByRunID(ctx, runID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HumanLayer client: %w", err)
|
||||
return "", fmt.Errorf("failed to get session by run_id: %w", err)
|
||||
}
|
||||
if session == nil {
|
||||
return "", fmt.Errorf("session not found for run_id: %s", runID)
|
||||
}
|
||||
|
||||
// Create in-memory store
|
||||
store := NewMemoryStore()
|
||||
|
||||
// Create poller with configured interval
|
||||
poller := NewPoller(client, store, conversationStore, cfg.PollInterval, eventBus)
|
||||
poller.maxBackoff = cfg.MaxBackoff
|
||||
poller.backoffFactor = cfg.BackoffFactor
|
||||
|
||||
return &DefaultManager{
|
||||
Client: client,
|
||||
Store: store,
|
||||
Poller: poller,
|
||||
EventBus: eventBus,
|
||||
ConversationStore: conversationStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins polling for approvals
|
||||
func (m *DefaultManager) Start(ctx context.Context) error {
|
||||
return m.Poller.Start(ctx)
|
||||
}
|
||||
|
||||
// Stop stops the approval manager
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.Poller.Stop()
|
||||
}
|
||||
|
||||
// GetPendingApprovals returns all pending approvals, optionally filtered by session
|
||||
func (m *DefaultManager) GetPendingApprovals(sessionID string) ([]PendingApproval, error) {
|
||||
if sessionID != "" && m.ConversationStore != nil {
|
||||
// Look up the session to get its run_id
|
||||
ctx := context.Background()
|
||||
session, err := m.ConversationStore.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
slog.Debug("session not found for approval filter", "session_id", sessionID, "error", err)
|
||||
return []PendingApproval{}, nil
|
||||
}
|
||||
// Get approvals by run_id
|
||||
return m.Store.GetPendingByRunID(session.RunID)
|
||||
// Create approval
|
||||
approval := &store.Approval{
|
||||
ID: "local-" + uuid.New().String(),
|
||||
RunID: runID,
|
||||
SessionID: session.ID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
CreatedAt: time.Now(),
|
||||
ToolName: toolName,
|
||||
ToolInput: toolInput,
|
||||
}
|
||||
return m.Store.GetAllPending()
|
||||
}
|
||||
|
||||
// GetPendingApprovalsByRunID returns pending approvals for a specific run_id
|
||||
func (m *DefaultManager) GetPendingApprovalsByRunID(runID string) ([]PendingApproval, error) {
|
||||
return m.Store.GetPendingByRunID(runID)
|
||||
}
|
||||
|
||||
// ErrAlreadyResponded is returned when an approval has already been responded to
|
||||
var ErrAlreadyResponded = errors.New("this approval has already been responded to")
|
||||
|
||||
// handleConflictError checks if an error is a conflict and handles it appropriately
|
||||
func (m *DefaultManager) handleConflictError(ctx context.Context, err error, callID string, approvalType string) error {
|
||||
var apiErr *humanlayer.APIError
|
||||
if errors.As(err, &apiErr) && apiErr.IsConflict() {
|
||||
slog.Info("approval already responded externally",
|
||||
"type", approvalType,
|
||||
"call_id", callID,
|
||||
"error", apiErr.Body)
|
||||
|
||||
// Remove from local cache
|
||||
if approvalType == "function_call" {
|
||||
if err := m.Store.RemoveFunctionCall(callID); err != nil {
|
||||
slog.Error("failed to remove function call from cache", "call_id", callID, "error", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.Store.RemoveHumanContact(callID); err != nil {
|
||||
slog.Error("failed to remove human contact from cache", "call_id", callID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update database status if we have a conversation store
|
||||
if m.ConversationStore != nil {
|
||||
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusResolved); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Return a specific error to indicate it was already responded
|
||||
return ErrAlreadyResponded
|
||||
// Store it
|
||||
if err := m.store.CreateApproval(ctx, approval); err != nil {
|
||||
return "", fmt.Errorf("failed to store approval: %w", err)
|
||||
}
|
||||
return err
|
||||
|
||||
// Try to correlate with the most recent uncorrelated tool call
|
||||
if err := m.correlateApproval(ctx, approval); err != nil {
|
||||
// Log but don't fail - correlation is best effort
|
||||
slog.Warn("failed to correlate approval with tool call",
|
||||
"error", err,
|
||||
"approval_id", approval.ID,
|
||||
"session_id", session.ID)
|
||||
}
|
||||
|
||||
// Publish event for real-time updates
|
||||
m.publishNewApprovalEvent(approval)
|
||||
|
||||
// Update session status to waiting_input
|
||||
if err := m.updateSessionStatus(ctx, session.ID, store.SessionStatusWaitingInput); err != nil {
|
||||
slog.Warn("failed to update session status",
|
||||
"error", err,
|
||||
"session_id", session.ID)
|
||||
}
|
||||
|
||||
slog.Info("created local approval",
|
||||
"approval_id", approval.ID,
|
||||
"session_id", session.ID,
|
||||
"tool_name", toolName)
|
||||
|
||||
return approval.ID, nil
|
||||
}
|
||||
|
||||
// ApproveFunctionCall approves a function call
|
||||
func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
|
||||
// First check if we have this function call
|
||||
fc, err := m.Store.GetFunctionCall(callID)
|
||||
// GetPendingApprovals retrieves pending approvals for a session
|
||||
func (m *manager) GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error) {
|
||||
approvals, err := m.store.GetPendingApprovals(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("function call not found: %w", err)
|
||||
return nil, fmt.Errorf("failed to get pending approvals: %w", err)
|
||||
}
|
||||
return approvals, nil
|
||||
}
|
||||
|
||||
// GetApproval retrieves a specific approval by ID
|
||||
func (m *manager) GetApproval(ctx context.Context, id string) (*store.Approval, error) {
|
||||
approval, err := m.store.GetApproval(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get approval: %w", err)
|
||||
}
|
||||
return approval, nil
|
||||
}
|
||||
|
||||
// ApproveToolCall approves a tool call
|
||||
func (m *manager) ApproveToolCall(ctx context.Context, id string, comment string) error {
|
||||
// Get the approval first
|
||||
approval, err := m.store.GetApproval(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get approval: %w", err)
|
||||
}
|
||||
|
||||
// Send approval to API
|
||||
if err := m.Client.ApproveFunctionCall(ctx, callID, comment); err != nil {
|
||||
// Handle conflict error specially
|
||||
conflictErr := m.handleConflictError(ctx, err, callID, "function_call")
|
||||
if errors.Is(conflictErr, ErrAlreadyResponded) {
|
||||
// Return the already responded error directly
|
||||
return conflictErr
|
||||
}
|
||||
if conflictErr != nil {
|
||||
return fmt.Errorf("failed to approve function call: %w", conflictErr)
|
||||
}
|
||||
// Update approval status
|
||||
if err := m.store.UpdateApprovalResponse(ctx, id, store.ApprovalStatusLocalApproved, comment); err != nil {
|
||||
return fmt.Errorf("failed to update approval: %w", err)
|
||||
}
|
||||
|
||||
// Mark as responded in local store
|
||||
if err := m.Store.MarkFunctionCallResponded(callID); err != nil {
|
||||
return fmt.Errorf("failed to update local state: %w", err)
|
||||
}
|
||||
|
||||
// Update approval status in database if we have a conversation store
|
||||
if m.ConversationStore != nil && fc != nil {
|
||||
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusApproved); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
// Don't fail the whole operation for this
|
||||
}
|
||||
|
||||
// Update session status back to running since approval is resolved
|
||||
if fc.RunID != "" {
|
||||
// Look up the session by run_id
|
||||
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
|
||||
runningStatus := store.SessionStatusRunning
|
||||
update := store.SessionUpdate{
|
||||
Status: &runningStatus,
|
||||
}
|
||||
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to running",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
} else {
|
||||
slog.Info("updated session status back to running after approval",
|
||||
"session_id", session.ID,
|
||||
"approval_id", callID)
|
||||
|
||||
// Publish session status change event
|
||||
if m.EventBus != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
// TODO(4): Can this be a static type later on? Why isn't it currently? Is this because of JSON RPC or a go thing?
|
||||
Data: map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"run_id": fc.RunID,
|
||||
"old_status": string(store.SessionStatusWaitingInput),
|
||||
"new_status": string(store.SessionStatusRunning),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update correlation status in conversation events
|
||||
if err := m.store.UpdateApprovalStatus(ctx, id, store.ApprovalStatusApproved); err != nil {
|
||||
slog.Warn("failed to update approval status in conversation events",
|
||||
"error", err,
|
||||
"approval_id", id)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.EventBus != nil && fc != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"run_id": fc.RunID,
|
||||
"decision": "approved",
|
||||
"comment": comment,
|
||||
},
|
||||
})
|
||||
m.publishApprovalResolvedEvent(approval, true, comment)
|
||||
|
||||
// Update session status back to running
|
||||
if err := m.updateSessionStatus(ctx, approval.SessionID, store.SessionStatusRunning); err != nil {
|
||||
slog.Warn("failed to update session status",
|
||||
"error", err,
|
||||
"session_id", approval.SessionID)
|
||||
}
|
||||
|
||||
slog.Info("approved tool call",
|
||||
"approval_id", id,
|
||||
"comment", comment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DenyToolCall denies a tool call
|
||||
func (m *manager) DenyToolCall(ctx context.Context, id string, reason string) error {
|
||||
// Get the approval first
|
||||
approval, err := m.store.GetApproval(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get approval: %w", err)
|
||||
}
|
||||
|
||||
// Update approval status
|
||||
if err := m.store.UpdateApprovalResponse(ctx, id, store.ApprovalStatusLocalDenied, reason); err != nil {
|
||||
return fmt.Errorf("failed to update approval: %w", err)
|
||||
}
|
||||
|
||||
// Update correlation status in conversation events
|
||||
if err := m.store.UpdateApprovalStatus(ctx, id, store.ApprovalStatusDenied); err != nil {
|
||||
slog.Warn("failed to update approval status in conversation events",
|
||||
"error", err,
|
||||
"approval_id", id)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
m.publishApprovalResolvedEvent(approval, false, reason)
|
||||
|
||||
// Update session status back to running
|
||||
if err := m.updateSessionStatus(ctx, approval.SessionID, store.SessionStatusRunning); err != nil {
|
||||
slog.Warn("failed to update session status",
|
||||
"error", err,
|
||||
"session_id", approval.SessionID)
|
||||
}
|
||||
|
||||
slog.Info("denied tool call",
|
||||
"approval_id", id,
|
||||
"reason", reason)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// correlateApproval tries to correlate an approval with a tool call
|
||||
func (m *manager) correlateApproval(ctx context.Context, approval *store.Approval) error {
|
||||
// Find the most recent uncorrelated pending tool call
|
||||
toolCall, err := m.store.GetUncorrelatedPendingToolCall(ctx, approval.SessionID, approval.ToolName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find pending tool call: %w", err)
|
||||
}
|
||||
if toolCall == nil {
|
||||
return fmt.Errorf("no matching tool call found")
|
||||
}
|
||||
|
||||
// Correlate by tool ID
|
||||
if err := m.store.CorrelateApprovalByToolID(ctx, approval.SessionID, toolCall.ToolID, approval.ID); err != nil {
|
||||
return fmt.Errorf("failed to correlate approval: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DenyFunctionCall denies a function call
|
||||
func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
|
||||
// First check if we have this function call
|
||||
fc, err := m.Store.GetFunctionCall(callID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("function call not found: %w", err)
|
||||
}
|
||||
|
||||
// Send denial to API
|
||||
if err := m.Client.DenyFunctionCall(ctx, callID, reason); err != nil {
|
||||
// Handle conflict error specially
|
||||
conflictErr := m.handleConflictError(ctx, err, callID, "function_call")
|
||||
if errors.Is(conflictErr, ErrAlreadyResponded) {
|
||||
// Return the already responded error directly
|
||||
return conflictErr
|
||||
}
|
||||
if conflictErr != nil {
|
||||
return fmt.Errorf("failed to deny function call: %w", conflictErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as responded in local store
|
||||
if err := m.Store.MarkFunctionCallResponded(callID); err != nil {
|
||||
return fmt.Errorf("failed to update local state: %w", err)
|
||||
}
|
||||
|
||||
// Update approval status in database if we have a conversation store
|
||||
if m.ConversationStore != nil && fc != nil {
|
||||
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusDenied); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
// Don't fail the whole operation for this
|
||||
}
|
||||
|
||||
// Update session status back to running since approval is resolved (denied)
|
||||
if fc.RunID != "" {
|
||||
// Look up the session by run_id
|
||||
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
|
||||
runningStatus := store.SessionStatusRunning
|
||||
update := store.SessionUpdate{
|
||||
Status: &runningStatus,
|
||||
}
|
||||
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to running",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
} else {
|
||||
slog.Info("updated session status back to running after denial",
|
||||
"session_id", session.ID,
|
||||
"approval_id", callID)
|
||||
|
||||
// Publish session status change event
|
||||
if m.EventBus != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
Data: map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"run_id": fc.RunID,
|
||||
"old_status": string(store.SessionStatusWaitingInput),
|
||||
"new_status": string(store.SessionStatusRunning),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.EventBus != nil && fc != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
// publishNewApprovalEvent publishes an event when a new approval is created
|
||||
func (m *manager) publishNewApprovalEvent(approval *store.Approval) {
|
||||
if m.eventBus != nil {
|
||||
event := bus.Event{
|
||||
Type: bus.EventNewApproval,
|
||||
Timestamp: time.Now(),
|
||||
Data: map[string]interface{}{
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"run_id": fc.RunID,
|
||||
"decision": "denied",
|
||||
"reason": reason,
|
||||
"approval_id": approval.ID,
|
||||
"session_id": approval.SessionID,
|
||||
"tool_name": approval.ToolName,
|
||||
},
|
||||
})
|
||||
}
|
||||
m.eventBus.Publish(event)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RespondToHumanContact sends a response to a human contact request
|
||||
func (m *DefaultManager) RespondToHumanContact(ctx context.Context, callID string, response string) error {
|
||||
// First check if we have this human contact
|
||||
hc, err := m.Store.GetHumanContact(callID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("human contact not found: %w", err)
|
||||
}
|
||||
|
||||
// Send response to API
|
||||
if err := m.Client.RespondToHumanContact(ctx, callID, response); err != nil {
|
||||
// Handle conflict error specially
|
||||
conflictErr := m.handleConflictError(ctx, err, callID, "human_contact")
|
||||
if errors.Is(conflictErr, ErrAlreadyResponded) {
|
||||
// Return the already responded error directly
|
||||
return conflictErr
|
||||
}
|
||||
if conflictErr != nil {
|
||||
return fmt.Errorf("failed to respond to human contact: %w", conflictErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as responded in local store
|
||||
if err := m.Store.MarkHumanContactResponded(callID); err != nil {
|
||||
return fmt.Errorf("failed to update local state: %w", err)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.EventBus != nil && hc != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
// publishApprovalResolvedEvent publishes an event when an approval is resolved
|
||||
func (m *manager) publishApprovalResolvedEvent(approval *store.Approval, approved bool, responseText string) {
|
||||
if m.eventBus != nil {
|
||||
event := bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Timestamp: time.Now(),
|
||||
Data: map[string]interface{}{
|
||||
"type": "human_contact",
|
||||
"call_id": callID,
|
||||
"run_id": hc.RunID,
|
||||
"decision": "responded",
|
||||
"response": response,
|
||||
"approval_id": approval.ID,
|
||||
"session_id": approval.SessionID,
|
||||
"approved": approved,
|
||||
"response_text": responseText,
|
||||
},
|
||||
})
|
||||
}
|
||||
m.eventBus.Publish(event)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReconcileApprovalsForSession reconciles approvals for a session after restart
|
||||
func (m *DefaultManager) ReconcileApprovalsForSession(ctx context.Context, runID string) error {
|
||||
if m.Poller == nil {
|
||||
return nil // No poller configured
|
||||
// updateSessionStatus updates the session status
|
||||
func (m *manager) updateSessionStatus(ctx context.Context, sessionID, status string) error {
|
||||
updates := store.SessionUpdate{
|
||||
Status: &status,
|
||||
LastActivityAt: &[]time.Time{time.Now()}[0],
|
||||
}
|
||||
return m.Poller.ReconcileApprovalsForSession(ctx, runID)
|
||||
return m.store.UpdateSession(ctx, sessionID, updates)
|
||||
}
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestManager_SessionStatusTransitions(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupTest func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string // returns callID
|
||||
expectedStatus string
|
||||
}{
|
||||
{
|
||||
name: "approve_updates_session_to_running",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-approve-123"
|
||||
runID := "run-approve-456"
|
||||
sessionID := "sess-approve-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).AnyTimes()
|
||||
|
||||
// API approval succeeds
|
||||
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
|
||||
|
||||
// Get session by run_id
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusWaitingInput,
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// Expect status update to running
|
||||
convStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusRunning {
|
||||
t.Errorf("expected status update to running, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
{
|
||||
name: "deny_updates_session_to_running",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-deny-123"
|
||||
runID := "run-deny-456"
|
||||
sessionID := "sess-deny-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
|
||||
|
||||
// API denial succeeds
|
||||
mockClient.EXPECT().DenyFunctionCall(gomock.Any(), callID, "test reason").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusDenied).Return(nil)
|
||||
|
||||
// Get session by run_id
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusWaitingInput,
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// Expect status update to running
|
||||
convStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusRunning {
|
||||
t.Errorf("expected status update to running, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
{
|
||||
name: "no_status_update_for_non_waiting_session",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-nochange-123"
|
||||
runID := "run-nochange-456"
|
||||
sessionID := "sess-nochange-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
|
||||
|
||||
// API approval succeeds
|
||||
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
|
||||
|
||||
// Get session - session is already running, not waiting
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusRunning, // Already running
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// No status update should happen since session is not waiting
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a new controller for each test to isolate expectations
|
||||
testCtrl := gomock.NewController(t)
|
||||
defer testCtrl.Finish()
|
||||
|
||||
// Create mocks
|
||||
mockClient := NewMockAPIClient(testCtrl)
|
||||
mockStore := NewMockStore(testCtrl)
|
||||
convStore := store.NewMockConversationStore(testCtrl)
|
||||
|
||||
// Create manager
|
||||
manager := &DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: mockStore,
|
||||
ConversationStore: convStore,
|
||||
}
|
||||
|
||||
// Setup test expectations and get callID
|
||||
callID := tt.setupTest(mockClient, mockStore, convStore)
|
||||
|
||||
// Execute based on test type
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
if tt.name == "deny_updates_session_to_running" {
|
||||
err = manager.DenyFunctionCall(ctx, callID, "test reason")
|
||||
} else {
|
||||
err = manager.ApproveFunctionCall(ctx, callID, "test comment")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Expectations are verified by gomock
|
||||
})
|
||||
}
|
||||
}
|
||||
263
hld/approval/manager_test.go
Normal file
263
hld/approval/manager_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestManager_CreateApproval(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
runID := "test-run-123"
|
||||
sessionID := "test-session-456"
|
||||
toolName := "Write"
|
||||
toolInput := json.RawMessage(`{"file": "test.txt", "content": "hello"}`)
|
||||
|
||||
// Mock getting session by run ID
|
||||
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(&store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
}, nil)
|
||||
|
||||
// Mock creating approval
|
||||
mockStore.EXPECT().CreateApproval(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, approval *store.Approval) error {
|
||||
assert.Equal(t, runID, approval.RunID)
|
||||
assert.Equal(t, sessionID, approval.SessionID)
|
||||
assert.Equal(t, store.ApprovalStatusLocalPending, approval.Status)
|
||||
assert.Equal(t, toolName, approval.ToolName)
|
||||
assert.Equal(t, toolInput, approval.ToolInput)
|
||||
assert.NotEmpty(t, approval.ID)
|
||||
assert.True(t, strings.HasPrefix(approval.ID, "local-"))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Mock correlation attempt - it's ok if it fails
|
||||
mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(nil, nil)
|
||||
|
||||
// Mock event publishing
|
||||
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
|
||||
assert.Equal(t, bus.EventNewApproval, event.Type)
|
||||
assert.Equal(t, sessionID, event.Data["session_id"])
|
||||
assert.Equal(t, toolName, event.Data["tool_name"])
|
||||
})
|
||||
|
||||
// Mock session status update
|
||||
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
|
||||
|
||||
// Create approval
|
||||
approvalID, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, approvalID)
|
||||
assert.True(t, strings.HasPrefix(approvalID, "local-"))
|
||||
}
|
||||
|
||||
func TestManager_CreateApproval_SessionNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
runID := "test-run-123"
|
||||
toolName := "Write"
|
||||
toolInput := json.RawMessage(`{"file": "test.txt"}`)
|
||||
|
||||
// Mock getting session by run ID - returns nil
|
||||
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(nil, nil)
|
||||
|
||||
// Create approval should fail
|
||||
_, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "session not found")
|
||||
}
|
||||
|
||||
func TestManager_GetPendingApprovals(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
sessionID := "test-session-456"
|
||||
|
||||
expectedApprovals := []*store.Approval{
|
||||
{
|
||||
ID: "local-approval-1",
|
||||
SessionID: sessionID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "Write",
|
||||
},
|
||||
{
|
||||
ID: "local-approval-2",
|
||||
SessionID: sessionID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "Execute",
|
||||
},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().GetPendingApprovals(ctx, sessionID).Return(expectedApprovals, nil)
|
||||
|
||||
approvals, err := manager.GetPendingApprovals(ctx, sessionID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedApprovals, approvals)
|
||||
}
|
||||
|
||||
func TestManager_ApproveToolCall(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
approvalID := "local-approval-123"
|
||||
sessionID := "test-session-456"
|
||||
comment := "Looks good!"
|
||||
|
||||
approval := &store.Approval{
|
||||
ID: approvalID,
|
||||
SessionID: sessionID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "Write",
|
||||
}
|
||||
|
||||
// Mock getting approval
|
||||
mockStore.EXPECT().GetApproval(ctx, approvalID).Return(approval, nil)
|
||||
|
||||
// Mock updating approval response
|
||||
mockStore.EXPECT().UpdateApprovalResponse(ctx, approvalID, store.ApprovalStatusLocalApproved, comment).Return(nil)
|
||||
|
||||
// Mock updating approval status in conversation events
|
||||
mockStore.EXPECT().UpdateApprovalStatus(ctx, approvalID, store.ApprovalStatusApproved).Return(nil)
|
||||
|
||||
// Mock event publishing
|
||||
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
|
||||
assert.Equal(t, bus.EventApprovalResolved, event.Type)
|
||||
assert.Equal(t, approvalID, event.Data["approval_id"])
|
||||
assert.Equal(t, sessionID, event.Data["session_id"])
|
||||
assert.Equal(t, true, event.Data["approved"])
|
||||
assert.Equal(t, comment, event.Data["response_text"])
|
||||
})
|
||||
|
||||
// Mock session status update
|
||||
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
|
||||
|
||||
err := manager.ApproveToolCall(ctx, approvalID, comment)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_DenyToolCall(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
approvalID := "local-approval-123"
|
||||
sessionID := "test-session-456"
|
||||
reason := "Not safe to execute"
|
||||
|
||||
approval := &store.Approval{
|
||||
ID: approvalID,
|
||||
SessionID: sessionID,
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
ToolName: "Execute",
|
||||
}
|
||||
|
||||
// Mock getting approval
|
||||
mockStore.EXPECT().GetApproval(ctx, approvalID).Return(approval, nil)
|
||||
|
||||
// Mock updating approval response
|
||||
mockStore.EXPECT().UpdateApprovalResponse(ctx, approvalID, store.ApprovalStatusLocalDenied, reason).Return(nil)
|
||||
|
||||
// Mock updating approval status in conversation events
|
||||
mockStore.EXPECT().UpdateApprovalStatus(ctx, approvalID, store.ApprovalStatusDenied).Return(nil)
|
||||
|
||||
// Mock event publishing
|
||||
mockEventBus.EXPECT().Publish(gomock.Any()).Do(func(event bus.Event) {
|
||||
assert.Equal(t, bus.EventApprovalResolved, event.Type)
|
||||
assert.Equal(t, approvalID, event.Data["approval_id"])
|
||||
assert.Equal(t, sessionID, event.Data["session_id"])
|
||||
assert.Equal(t, false, event.Data["approved"])
|
||||
assert.Equal(t, reason, event.Data["response_text"])
|
||||
})
|
||||
|
||||
// Mock session status update
|
||||
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
|
||||
|
||||
err := manager.DenyToolCall(ctx, approvalID, reason)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_CorrelateApproval(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
mockEventBus := bus.NewMockEventBus(ctrl)
|
||||
|
||||
manager := NewManager(mockStore, mockEventBus)
|
||||
|
||||
ctx := context.Background()
|
||||
runID := "test-run-123"
|
||||
sessionID := "test-session-456"
|
||||
toolName := "Write"
|
||||
toolInput := json.RawMessage(`{"file": "test.txt"}`)
|
||||
|
||||
// Mock getting session by run ID
|
||||
mockStore.EXPECT().GetSessionByRunID(ctx, runID).Return(&store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
}, nil)
|
||||
|
||||
// Mock creating approval
|
||||
mockStore.EXPECT().CreateApproval(ctx, gomock.Any()).Return(nil)
|
||||
|
||||
// Mock successful correlation
|
||||
pendingToolCall := &store.ConversationEvent{
|
||||
ID: 123,
|
||||
ToolID: "tool-123",
|
||||
ToolName: toolName,
|
||||
}
|
||||
mockStore.EXPECT().GetUncorrelatedPendingToolCall(ctx, sessionID, toolName).Return(pendingToolCall, nil)
|
||||
|
||||
// Mock correlating by tool ID
|
||||
mockStore.EXPECT().CorrelateApprovalByToolID(ctx, sessionID, "tool-123", gomock.Any()).Return(nil)
|
||||
|
||||
// Mock event publishing
|
||||
mockEventBus.EXPECT().Publish(gomock.Any())
|
||||
|
||||
// Mock session status update
|
||||
mockStore.EXPECT().UpdateSession(ctx, sessionID, gomock.Any()).Return(nil)
|
||||
|
||||
// Create approval (which will attempt correlation)
|
||||
approvalID, err := manager.CreateApproval(ctx, runID, toolName, toolInput)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, approvalID)
|
||||
}
|
||||
@@ -1,561 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// Poller polls the HumanLayer API for pending approvals
|
||||
type Poller struct {
|
||||
client APIClient
|
||||
store Store
|
||||
conversationStore store.ConversationStore
|
||||
eventBus bus.EventBus
|
||||
interval time.Duration
|
||||
maxBackoff time.Duration
|
||||
backoffFactor float64
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
failureCount int
|
||||
}
|
||||
|
||||
// NewPoller creates a new approval poller
|
||||
func NewPoller(client APIClient, store Store, conversationStore store.ConversationStore, interval time.Duration, eventBus bus.EventBus) *Poller {
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
return &Poller{
|
||||
client: client,
|
||||
store: store,
|
||||
conversationStore: conversationStore,
|
||||
eventBus: eventBus,
|
||||
interval: interval,
|
||||
maxBackoff: 5 * time.Minute,
|
||||
backoffFactor: 2.0,
|
||||
failureCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins polling for approvals
|
||||
func (p *Poller) Start(ctx context.Context) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cancel != nil {
|
||||
return fmt.Errorf("poller already started")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
p.cancel = cancel
|
||||
|
||||
go p.pollLoop(ctx)
|
||||
slog.Info("approval poller started", "interval", p.interval)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the polling loop
|
||||
func (p *Poller) Stop() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
p.cancel = nil
|
||||
slog.Info("approval poller stopped")
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns true if the poller is currently running
|
||||
func (p *Poller) IsRunning() bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.cancel != nil
|
||||
}
|
||||
|
||||
// pollLoop continuously polls for approvals
|
||||
func (p *Poller) pollLoop(ctx context.Context) {
|
||||
// Poll immediately on start
|
||||
p.poll(ctx)
|
||||
|
||||
for {
|
||||
// Calculate next poll interval based on failure count
|
||||
interval := p.calculateInterval()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
p.poll(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateInterval returns the next polling interval with exponential backoff
|
||||
func (p *Poller) calculateInterval() time.Duration {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.failureCount == 0 {
|
||||
return p.interval
|
||||
}
|
||||
|
||||
// Calculate backoff: interval * (backoffFactor ^ failureCount)
|
||||
backoff := float64(p.interval)
|
||||
for i := 0; i < p.failureCount; i++ {
|
||||
backoff *= p.backoffFactor
|
||||
if time.Duration(backoff) > p.maxBackoff {
|
||||
return p.maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// poll fetches and stores pending approvals
|
||||
func (p *Poller) poll(ctx context.Context) {
|
||||
// Create a timeout context for this poll
|
||||
pollCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var hadError bool
|
||||
|
||||
// Fetch function calls
|
||||
functionCalls, err := p.client.GetPendingFunctionCalls(pollCtx)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch function calls", "error", err)
|
||||
hadError = true
|
||||
} else {
|
||||
// Reconcile with cached state
|
||||
p.reconcileFunctionCalls(pollCtx, functionCalls)
|
||||
}
|
||||
|
||||
// Fetch human contacts
|
||||
humanContacts, err := p.client.GetPendingHumanContacts(pollCtx)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch human contacts", "error", err)
|
||||
hadError = true
|
||||
} else {
|
||||
// Reconcile with cached state
|
||||
p.reconcileHumanContacts(pollCtx, humanContacts)
|
||||
}
|
||||
|
||||
// Update failure count based on results
|
||||
p.updateFailureCount(hadError)
|
||||
}
|
||||
|
||||
// reconcileFunctionCalls reconciles fetched function calls with cached state
|
||||
func (p *Poller) reconcileFunctionCalls(ctx context.Context, functionCalls []humanlayer.FunctionCall) {
|
||||
// Get all cached function calls
|
||||
cachedCalls, err := p.store.GetAllCachedFunctionCalls()
|
||||
if err != nil {
|
||||
slog.Error("failed to get cached function calls", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a map of fetched calls for quick lookup
|
||||
fetchedMap := make(map[string]*humanlayer.FunctionCall)
|
||||
for i := range functionCalls {
|
||||
fc := &functionCalls[i]
|
||||
fetchedMap[fc.CallID] = fc
|
||||
}
|
||||
|
||||
// Check for calls that were resolved externally
|
||||
removedCount := 0
|
||||
for _, cached := range cachedCalls {
|
||||
if _, exists := fetchedMap[cached.CallID]; !exists {
|
||||
// This approval is no longer pending, it was resolved externally
|
||||
slog.Info("detected externally resolved function call",
|
||||
"call_id", cached.CallID,
|
||||
"run_id", cached.RunID)
|
||||
|
||||
// Remove from local store
|
||||
if err := p.store.RemoveFunctionCall(cached.CallID); err != nil {
|
||||
slog.Error("failed to remove resolved function call", "error", err)
|
||||
} else {
|
||||
removedCount++
|
||||
}
|
||||
|
||||
// Update database status if we have a conversation store
|
||||
if p.conversationStore != nil {
|
||||
if err := p.conversationStore.UpdateApprovalStatus(ctx, cached.CallID, store.ApprovalStatusResolved); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
}
|
||||
|
||||
// Update session status back to running if it was waiting
|
||||
if cached.RunID != "" {
|
||||
session, err := p.conversationStore.GetSessionByRunID(ctx, cached.RunID)
|
||||
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
|
||||
runningStatus := store.SessionStatusRunning
|
||||
update := store.SessionUpdate{
|
||||
Status: &runningStatus,
|
||||
}
|
||||
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status", "error", err)
|
||||
} else {
|
||||
// Publish session status change event
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
Data: map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"run_id": cached.RunID,
|
||||
"old_status": string(store.SessionStatusWaitingInput),
|
||||
"new_status": string(store.SessionStatusRunning),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
"type": "function_call",
|
||||
"call_id": cached.CallID,
|
||||
"run_id": cached.RunID,
|
||||
"resolved_externally": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store new/updated function calls
|
||||
newCount := 0
|
||||
for _, fc := range functionCalls {
|
||||
// Check if this is a new approval
|
||||
if existing, err := p.store.GetFunctionCall(fc.CallID); err != nil || existing == nil {
|
||||
newCount++
|
||||
}
|
||||
|
||||
if err := p.store.StoreFunctionCall(fc); err != nil {
|
||||
slog.Error("failed to store function call", "call_id", fc.CallID, "error", err)
|
||||
}
|
||||
|
||||
// Correlate with tool calls in the database
|
||||
if p.conversationStore != nil && fc.RunID != "" {
|
||||
p.correlateApproval(ctx, fc)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("reconciled function calls",
|
||||
"fetched", len(functionCalls),
|
||||
"new", newCount,
|
||||
"removed", removedCount)
|
||||
|
||||
// Publish event if we have new approvals
|
||||
if newCount > 0 && p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventNewApproval,
|
||||
Data: map[string]interface{}{
|
||||
"type": "function_call",
|
||||
"count": newCount,
|
||||
"total": len(functionCalls),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileHumanContacts reconciles fetched human contacts with cached state
|
||||
func (p *Poller) reconcileHumanContacts(ctx context.Context, humanContacts []humanlayer.HumanContact) {
|
||||
// Get all cached human contacts
|
||||
cachedContacts, err := p.store.GetAllCachedHumanContacts()
|
||||
if err != nil {
|
||||
slog.Error("failed to get cached human contacts", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a map of fetched contacts for quick lookup
|
||||
fetchedMap := make(map[string]*humanlayer.HumanContact)
|
||||
for i := range humanContacts {
|
||||
hc := &humanContacts[i]
|
||||
fetchedMap[hc.CallID] = hc
|
||||
}
|
||||
|
||||
// Check for contacts that were resolved externally
|
||||
removedCount := 0
|
||||
for _, cached := range cachedContacts {
|
||||
if _, exists := fetchedMap[cached.CallID]; !exists {
|
||||
// This contact is no longer pending, it was resolved externally
|
||||
slog.Info("detected externally resolved human contact",
|
||||
"call_id", cached.CallID,
|
||||
"run_id", cached.RunID)
|
||||
|
||||
// Remove from local store
|
||||
if err := p.store.RemoveHumanContact(cached.CallID); err != nil {
|
||||
slog.Error("failed to remove resolved human contact", "error", err)
|
||||
} else {
|
||||
removedCount++
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
"type": "human_contact",
|
||||
"call_id": cached.CallID,
|
||||
"run_id": cached.RunID,
|
||||
"resolved_externally": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store new/updated human contacts
|
||||
newCount := 0
|
||||
for _, hc := range humanContacts {
|
||||
// Check if this is a new approval
|
||||
if existing, err := p.store.GetHumanContact(hc.CallID); err != nil || existing == nil {
|
||||
newCount++
|
||||
}
|
||||
|
||||
if err := p.store.StoreHumanContact(hc); err != nil {
|
||||
slog.Error("failed to store human contact", "call_id", hc.CallID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("reconciled human contacts",
|
||||
"fetched", len(humanContacts),
|
||||
"new", newCount,
|
||||
"removed", removedCount)
|
||||
|
||||
// Publish event if we have new approvals
|
||||
if newCount > 0 && p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventNewApproval,
|
||||
Data: map[string]interface{}{
|
||||
"type": "human_contact",
|
||||
"count": newCount,
|
||||
"total": len(humanContacts),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ReconcileApprovalsForSession checks for any approvals that might belong to a session
|
||||
// This is useful when a session restarts and needs to reclaim pending approvals
|
||||
func (p *Poller) ReconcileApprovalsForSession(ctx context.Context, runID string) error {
|
||||
// Get all cached function calls
|
||||
functionCalls, err := p.store.GetAllCachedFunctionCalls()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cached function calls: %w", err)
|
||||
}
|
||||
|
||||
// Check each one to see if it matches this run_id
|
||||
for _, fc := range functionCalls {
|
||||
if fc.RunID == runID {
|
||||
// Re-correlate this approval
|
||||
p.correlateApproval(ctx, fc)
|
||||
}
|
||||
}
|
||||
|
||||
// Also check human contacts
|
||||
humanContacts, err := p.store.GetAllCachedHumanContacts()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cached human contacts: %w", err)
|
||||
}
|
||||
|
||||
for _, hc := range humanContacts {
|
||||
if hc.RunID == runID {
|
||||
slog.Info("found orphaned human contact for session",
|
||||
"call_id", hc.CallID,
|
||||
"run_id", runID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateFailureCount updates the failure count for backoff calculation
|
||||
func (p *Poller) updateFailureCount(hadError bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if hadError {
|
||||
p.failureCount++
|
||||
nextInterval := p.calculateIntervalLocked()
|
||||
slog.Warn("poll failed, backing off",
|
||||
"failure_count", p.failureCount,
|
||||
"next_interval", nextInterval)
|
||||
} else {
|
||||
if p.failureCount > 0 {
|
||||
slog.Info("poll succeeded, resetting backoff")
|
||||
p.failureCount = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateIntervalLocked returns the next interval (must be called with lock held)
|
||||
func (p *Poller) calculateIntervalLocked() time.Duration {
|
||||
if p.failureCount == 0 {
|
||||
return p.interval
|
||||
}
|
||||
|
||||
// Calculate backoff: interval * (backoffFactor ^ failureCount)
|
||||
backoff := float64(p.interval)
|
||||
for i := 0; i < p.failureCount; i++ {
|
||||
backoff *= p.backoffFactor
|
||||
if time.Duration(backoff) > p.maxBackoff {
|
||||
return p.maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// correlateApproval attempts to match an approval with a pending tool call
|
||||
func (p *Poller) correlateApproval(ctx context.Context, fc humanlayer.FunctionCall) {
|
||||
// Find the session with this run_id
|
||||
session, err := p.conversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err != nil {
|
||||
slog.Error("failed to get session for correlation",
|
||||
"run_id", fc.RunID,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
// This is expected for approvals that aren't from our Claude sessions
|
||||
slog.Debug("no matching session for approval",
|
||||
"run_id", fc.RunID,
|
||||
"approval_id", fc.CallID)
|
||||
return
|
||||
}
|
||||
|
||||
toolName := fc.Spec.Fn
|
||||
// Try to find an uncorrelated pending tool call for this session and tool
|
||||
// Use the specialized method that only finds tool calls without approvals
|
||||
var toolCall *store.ConversationEvent
|
||||
|
||||
// Check if the store has the GetUncorrelatedPendingToolCall method
|
||||
if uncorrelatedStore, ok := p.conversationStore.(*store.SQLiteStore); ok {
|
||||
toolCall, err = uncorrelatedStore.GetUncorrelatedPendingToolCall(ctx, session.ID, toolName)
|
||||
} else {
|
||||
// Fallback to regular GetPendingToolCall
|
||||
toolCall, err = p.conversationStore.GetPendingToolCall(ctx, session.ID, toolName)
|
||||
}
|
||||
|
||||
if err != nil || toolCall == nil {
|
||||
// For continued sessions, also check the parent session's tool calls
|
||||
if session.ParentSessionID != "" {
|
||||
parentToolCall, err := p.conversationStore.GetPendingToolCall(ctx, session.ParentSessionID, toolName)
|
||||
if err == nil && parentToolCall != nil {
|
||||
// Found in parent session - correlate with parent
|
||||
if err := p.conversationStore.CorrelateApproval(ctx, session.ParentSessionID, toolName, fc.CallID); err != nil {
|
||||
slog.Error("failed to correlate approval with parent session tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"parent_session_id", session.ParentSessionID,
|
||||
"tool_name", toolName,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("correlated approval with parent session tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"parent_session_id", session.ParentSessionID,
|
||||
"current_session_id", session.ID,
|
||||
"tool_name", toolName,
|
||||
"run_id", fc.RunID)
|
||||
|
||||
// Update current session status to waiting_input
|
||||
waitingStatus := store.SessionStatusWaitingInput
|
||||
update := store.SessionUpdate{
|
||||
Status: &waitingStatus,
|
||||
}
|
||||
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to waiting_input",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
}
|
||||
|
||||
// Publish session status change event
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
Data: map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"run_id": fc.RunID,
|
||||
"old_status": string(store.SessionStatusRunning),
|
||||
"new_status": string(store.SessionStatusWaitingInput),
|
||||
},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("no matching pending tool call for approval",
|
||||
"session_id", session.ID,
|
||||
"run_id", fc.RunID,
|
||||
"tool_name", toolName,
|
||||
"approval_id", fc.CallID)
|
||||
return
|
||||
}
|
||||
|
||||
// Found a matching tool call in one of our sessions - correlate it
|
||||
if err := p.conversationStore.CorrelateApproval(ctx, session.ID, toolName, fc.CallID); err != nil {
|
||||
slog.Error("failed to correlate approval with tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"session_id", session.ID,
|
||||
"tool_name", toolName,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("correlated approval with tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"session_id", session.ID,
|
||||
"tool_name", toolName,
|
||||
"run_id", fc.RunID)
|
||||
|
||||
// Publish event to notify that approval has been correlated
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
"approval_id": fc.CallID,
|
||||
"session_id": session.ID,
|
||||
"tool_name": toolName,
|
||||
"run_id": fc.RunID,
|
||||
"status": "correlated",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Update session status to waiting_input
|
||||
waitingStatus := store.SessionStatusWaitingInput
|
||||
update := store.SessionUpdate{
|
||||
Status: &waitingStatus,
|
||||
}
|
||||
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to waiting_input",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
}
|
||||
|
||||
// Publish session status change event
|
||||
if p.eventBus != nil {
|
||||
p.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
Data: map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"run_id": fc.RunID,
|
||||
"old_status": string(store.SessionStatusRunning),
|
||||
"new_status": string(store.SessionStatusWaitingInput),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestPoller_Poll(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
// Create mock client with test data
|
||||
mockClient := NewMockAPIClient(ctrl)
|
||||
mockStore := NewMockStore(ctrl)
|
||||
|
||||
testFunctionCalls := []humanlayer.FunctionCall{
|
||||
{CallID: "fc-1", RunID: "run-1"},
|
||||
{CallID: "fc-2", RunID: "run-2"},
|
||||
}
|
||||
testHumanContacts := []humanlayer.HumanContact{
|
||||
{CallID: "hc-1", RunID: "run-1"},
|
||||
}
|
||||
|
||||
// Set expectations
|
||||
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return(testFunctionCalls, nil)
|
||||
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return(testHumanContacts, nil)
|
||||
|
||||
// Expect reconciliation calls
|
||||
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil)
|
||||
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil)
|
||||
|
||||
// Expect checks for existing approvals
|
||||
mockStore.EXPECT().GetFunctionCall("fc-1").Return(nil, errors.New("not found"))
|
||||
mockStore.EXPECT().GetFunctionCall("fc-2").Return(nil, errors.New("not found"))
|
||||
mockStore.EXPECT().GetHumanContact("hc-1").Return(nil, errors.New("not found"))
|
||||
|
||||
mockStore.EXPECT().StoreFunctionCall(testFunctionCalls[0]).Return(nil)
|
||||
mockStore.EXPECT().StoreFunctionCall(testFunctionCalls[1]).Return(nil)
|
||||
mockStore.EXPECT().StoreHumanContact(testHumanContacts[0]).Return(nil)
|
||||
|
||||
// Create poller with short interval for testing
|
||||
poller := &Poller{
|
||||
client: mockClient,
|
||||
store: mockStore,
|
||||
interval: 10 * time.Millisecond,
|
||||
maxBackoff: 100 * time.Millisecond,
|
||||
backoffFactor: 2.0,
|
||||
}
|
||||
|
||||
// Poll once
|
||||
ctx := context.Background()
|
||||
poller.poll(ctx)
|
||||
|
||||
// Verify no failure count
|
||||
if poller.failureCount != 0 {
|
||||
t.Errorf("expected failure count 0, got %d", poller.failureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoller_Backoff(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockClient := NewMockAPIClient(ctrl)
|
||||
mockStore := NewMockStore(ctrl)
|
||||
|
||||
apiErr := errors.New("API error")
|
||||
|
||||
// First two calls fail
|
||||
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return(nil, apiErr).Times(2)
|
||||
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return(nil, apiErr).Times(2)
|
||||
|
||||
// Third call succeeds
|
||||
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil)
|
||||
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil)
|
||||
|
||||
// Expect reconciliation calls for successful poll
|
||||
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil)
|
||||
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil)
|
||||
|
||||
poller := &Poller{
|
||||
client: mockClient,
|
||||
store: mockStore,
|
||||
interval: 10 * time.Millisecond,
|
||||
maxBackoff: 100 * time.Millisecond,
|
||||
backoffFactor: 2.0,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First poll should fail and increment failure count
|
||||
poller.poll(ctx)
|
||||
if poller.failureCount != 1 {
|
||||
t.Errorf("expected failure count 1, got %d", poller.failureCount)
|
||||
}
|
||||
|
||||
// Calculate expected interval (10ms * 2^1 = 20ms)
|
||||
interval := poller.calculateInterval()
|
||||
if interval != 20*time.Millisecond {
|
||||
t.Errorf("expected interval 20ms, got %v", interval)
|
||||
}
|
||||
|
||||
// Second poll should fail and increment again
|
||||
poller.poll(ctx)
|
||||
if poller.failureCount != 2 {
|
||||
t.Errorf("expected failure count 2, got %d", poller.failureCount)
|
||||
}
|
||||
|
||||
// Calculate expected interval (10ms * 2^2 = 40ms)
|
||||
interval = poller.calculateInterval()
|
||||
if interval != 40*time.Millisecond {
|
||||
t.Errorf("expected interval 40ms, got %v", interval)
|
||||
}
|
||||
|
||||
// Test max backoff
|
||||
poller.failureCount = 10
|
||||
interval = poller.calculateInterval()
|
||||
if interval != poller.maxBackoff {
|
||||
t.Errorf("expected max backoff %v, got %v", poller.maxBackoff, interval)
|
||||
}
|
||||
|
||||
// Test reset on success
|
||||
poller.poll(ctx)
|
||||
if poller.failureCount != 0 {
|
||||
t.Errorf("expected failure count reset to 0, got %d", poller.failureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoller_StartStop(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockClient := NewMockAPIClient(ctrl)
|
||||
mockStore := NewMockStore(ctrl)
|
||||
|
||||
// Expect at least 2 polls during the test duration
|
||||
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
|
||||
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
|
||||
|
||||
// Expect reconciliation calls
|
||||
mockStore.EXPECT().GetAllCachedFunctionCalls().Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
|
||||
mockStore.EXPECT().GetAllCachedHumanContacts().Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
|
||||
|
||||
poller := NewPoller(mockClient, mockStore, nil, 50*time.Millisecond, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start poller
|
||||
err := poller.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start poller: %v", err)
|
||||
}
|
||||
|
||||
// Wait for a few polls
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
// Try to start again while running should fail
|
||||
err = poller.Start(ctx)
|
||||
if err == nil {
|
||||
t.Error("expected error starting poller twice")
|
||||
}
|
||||
|
||||
// Stop poller
|
||||
poller.Stop()
|
||||
|
||||
// Verify it's stopped
|
||||
if poller.IsRunning() {
|
||||
t.Error("expected poller to be stopped")
|
||||
}
|
||||
}
|
||||
@@ -2,62 +2,21 @@ package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
// PendingApproval wraps either a function call or human contact
|
||||
type PendingApproval struct {
|
||||
Type string `json:"type"` // "function_call" or "human_contact"
|
||||
FunctionCall *humanlayer.FunctionCall `json:"function_call,omitempty"`
|
||||
HumanContact *humanlayer.HumanContact `json:"human_contact,omitempty"`
|
||||
}
|
||||
|
||||
// Manager defines the interface for managing approvals
|
||||
// Manager defines the interface for managing local approvals
|
||||
type Manager interface {
|
||||
// Lifecycle methods
|
||||
Start(ctx context.Context) error
|
||||
Stop()
|
||||
// Create a new approval
|
||||
CreateApproval(ctx context.Context, runID, toolName string, toolInput json.RawMessage) (string, error)
|
||||
|
||||
// Retrieval methods
|
||||
GetPendingApprovals(sessionID string) ([]PendingApproval, error)
|
||||
GetPendingApprovalsByRunID(runID string) ([]PendingApproval, error)
|
||||
GetPendingApprovals(ctx context.Context, sessionID string) ([]*store.Approval, error)
|
||||
GetApproval(ctx context.Context, id string) (*store.Approval, error)
|
||||
|
||||
// Decision methods
|
||||
ApproveFunctionCall(ctx context.Context, callID string, comment string) error
|
||||
DenyFunctionCall(ctx context.Context, callID string, reason string) error
|
||||
RespondToHumanContact(ctx context.Context, callID string, response string) error
|
||||
|
||||
// Recovery methods
|
||||
ReconcileApprovalsForSession(ctx context.Context, runID string) error
|
||||
}
|
||||
|
||||
// Store manages approval storage and correlation
|
||||
type Store interface {
|
||||
// Storage methods
|
||||
StoreFunctionCall(fc humanlayer.FunctionCall) error
|
||||
StoreHumanContact(hc humanlayer.HumanContact) error
|
||||
|
||||
// Retrieval methods
|
||||
GetFunctionCall(callID string) (*humanlayer.FunctionCall, error)
|
||||
GetHumanContact(callID string) (*humanlayer.HumanContact, error)
|
||||
GetAllPending() ([]PendingApproval, error)
|
||||
GetPendingByRunID(runID string) ([]PendingApproval, error)
|
||||
GetAllCachedFunctionCalls() ([]humanlayer.FunctionCall, error)
|
||||
GetAllCachedHumanContacts() ([]humanlayer.HumanContact, error)
|
||||
|
||||
// Update methods
|
||||
MarkFunctionCallResponded(callID string) error
|
||||
MarkHumanContactResponded(callID string) error
|
||||
RemoveFunctionCall(callID string) error
|
||||
RemoveHumanContact(callID string) error
|
||||
}
|
||||
|
||||
// APIClient defines the interface for interacting with the HumanLayer API
|
||||
type APIClient interface {
|
||||
GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error)
|
||||
GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error)
|
||||
ApproveFunctionCall(ctx context.Context, callID string, comment string) error
|
||||
DenyFunctionCall(ctx context.Context, callID string, reason string) error
|
||||
RespondToHumanContact(ctx context.Context, callID string, response string) error
|
||||
ApproveToolCall(ctx context.Context, id string, comment string) error
|
||||
DenyToolCall(ctx context.Context, id string, reason string) error
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
// client provides a JSON-RPC 2.0 client for communicating with the HumanLayer daemon
|
||||
@@ -266,7 +266,7 @@ func (c *client) ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueS
|
||||
}
|
||||
|
||||
// FetchApprovals fetches pending approvals from the daemon
|
||||
func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, error) {
|
||||
func (c *client) FetchApprovals(sessionID string) ([]*store.Approval, error) {
|
||||
req := rpc.FetchApprovalsRequest{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
@@ -277,11 +277,11 @@ func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, e
|
||||
return resp.Approvals, nil
|
||||
}
|
||||
|
||||
// SendDecision sends a decision (approve/deny/respond) for an approval
|
||||
func (c *client) SendDecision(callID, approvalType, decision, comment string) error {
|
||||
// SendDecision sends a decision (approve/deny) for an approval
|
||||
func (c *client) SendDecision(approvalID, decision, comment string) error {
|
||||
req := rpc.SendDecisionRequest{
|
||||
CallID: callID,
|
||||
Type: approvalType,
|
||||
CallID: approvalID, // Using CallID for backward compatibility
|
||||
Type: "function_call", // Always function_call for local approvals
|
||||
Decision: decision,
|
||||
Comment: comment,
|
||||
}
|
||||
@@ -295,25 +295,17 @@ func (c *client) SendDecision(callID, approvalType, decision, comment string) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApproveFunctionCall approves a function call with an optional comment
|
||||
func (c *client) ApproveFunctionCall(callID, comment string) error {
|
||||
return c.SendDecision(callID, string(rpc.ApprovalTypeFunctionCall), string(rpc.DecisionApprove), comment)
|
||||
// ApproveToolCall approves a tool call with an optional comment
|
||||
func (c *client) ApproveToolCall(approvalID, comment string) error {
|
||||
return c.SendDecision(approvalID, "approve", comment)
|
||||
}
|
||||
|
||||
// DenyFunctionCall denies a function call with a required reason
|
||||
func (c *client) DenyFunctionCall(callID, reason string) error {
|
||||
// DenyToolCall denies a tool call with a required reason
|
||||
func (c *client) DenyToolCall(approvalID, reason string) error {
|
||||
if reason == "" {
|
||||
return fmt.Errorf("reason is required when denying a function call")
|
||||
return fmt.Errorf("reason is required when denying a tool call")
|
||||
}
|
||||
return c.SendDecision(callID, string(rpc.ApprovalTypeFunctionCall), string(rpc.DecisionDeny), reason)
|
||||
}
|
||||
|
||||
// RespondToHumanContact responds to a human contact request
|
||||
func (c *client) RespondToHumanContact(callID, response string) error {
|
||||
if response == "" {
|
||||
return fmt.Errorf("response is required for human contact")
|
||||
}
|
||||
return c.SendDecision(callID, string(rpc.ApprovalTypeHumanContact), string(rpc.DecisionRespond), response)
|
||||
return c.SendDecision(approvalID, "deny", reason)
|
||||
}
|
||||
|
||||
// GetConversation fetches the conversation history for a session
|
||||
|
||||
@@ -10,10 +10,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -143,30 +142,25 @@ func TestClient_FetchApprovals(t *testing.T) {
|
||||
server, socketPath := newMockRPCServer(t)
|
||||
defer server.stop()
|
||||
|
||||
// Create test approvals that match what the TUI expects
|
||||
testApprovals := []approval.PendingApproval{
|
||||
// Create test approvals with new local format
|
||||
testApprovals := []*store.Approval{
|
||||
{
|
||||
Type: "function_call",
|
||||
FunctionCall: &humanlayer.FunctionCall{
|
||||
CallID: "fc-123",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
Kwargs: map[string]interface{}{
|
||||
"arg": "value",
|
||||
},
|
||||
},
|
||||
RunID: "run-123",
|
||||
},
|
||||
ID: "local-123",
|
||||
RunID: "run-123",
|
||||
SessionID: "session-123",
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
CreatedAt: time.Now(),
|
||||
ToolName: "test_function",
|
||||
ToolInput: json.RawMessage(`{"arg": "value"}`),
|
||||
},
|
||||
{
|
||||
Type: "human_contact",
|
||||
HumanContact: &humanlayer.HumanContact{
|
||||
CallID: "hc-456",
|
||||
Spec: humanlayer.HumanContactSpec{
|
||||
Msg: "Need help with something",
|
||||
},
|
||||
RunID: "run-456",
|
||||
},
|
||||
ID: "local-456",
|
||||
RunID: "run-456",
|
||||
SessionID: "session-456",
|
||||
Status: store.ApprovalStatusLocalPending,
|
||||
CreatedAt: time.Now(),
|
||||
ToolName: "another_function",
|
||||
ToolInput: json.RawMessage(`{"msg": "test message"}`),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -187,15 +181,15 @@ func TestClient_FetchApprovals(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, approvals, 2)
|
||||
|
||||
// Verify function call
|
||||
assert.Equal(t, "function_call", approvals[0].Type)
|
||||
assert.NotNil(t, approvals[0].FunctionCall)
|
||||
assert.Equal(t, "fc-123", approvals[0].FunctionCall.CallID)
|
||||
// Verify first approval
|
||||
assert.Equal(t, "local-123", approvals[0].ID)
|
||||
assert.Equal(t, "test_function", approvals[0].ToolName)
|
||||
assert.Equal(t, store.ApprovalStatusLocalPending, approvals[0].Status)
|
||||
|
||||
// Verify human contact
|
||||
assert.Equal(t, "human_contact", approvals[1].Type)
|
||||
assert.NotNil(t, approvals[1].HumanContact)
|
||||
assert.Equal(t, "hc-456", approvals[1].HumanContact.CallID)
|
||||
// Verify second approval
|
||||
assert.Equal(t, "local-456", approvals[1].ID)
|
||||
assert.Equal(t, "another_function", approvals[1].ToolName)
|
||||
assert.Equal(t, store.ApprovalStatusLocalPending, approvals[1].Status)
|
||||
}
|
||||
|
||||
func TestClient_SendDecision(t *testing.T) {
|
||||
@@ -227,15 +221,18 @@ func TestClient_SendDecision(t *testing.T) {
|
||||
defer func() { _ = c.Close() }()
|
||||
|
||||
// Test approve
|
||||
err = c.SendDecision("test-123", "function_call", "approve", "looks good")
|
||||
err = c.SendDecision("test-123", "approve", "looks good")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test deny
|
||||
err = c.SendDecision("test-456", "function_call", "deny", "too risky")
|
||||
err = c.SendDecision("test-456", "deny", "too risky")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test respond
|
||||
err = c.SendDecision("test-789", "human_contact", "respond", "here is my response")
|
||||
// Test the convenience methods
|
||||
err = c.ApproveToolCall("test-789", "approved")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = c.DenyToolCall("test-890", "not allowed")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package client
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
// Client defines the interface for communicating with the HumanLayer daemon
|
||||
@@ -25,17 +25,14 @@ type Client interface {
|
||||
ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error)
|
||||
|
||||
// FetchApprovals fetches pending approvals from the daemon
|
||||
FetchApprovals(sessionID string) ([]approval.PendingApproval, error)
|
||||
FetchApprovals(sessionID string) ([]*store.Approval, error)
|
||||
|
||||
// SendDecision sends a decision (approve/deny/respond) for an approval
|
||||
SendDecision(callID, approvalType, decision, comment string) error
|
||||
// SendDecision sends a decision (approve/deny) for an approval
|
||||
SendDecision(approvalID, decision, comment string) error
|
||||
|
||||
// Type-safe approval methods for function calls
|
||||
ApproveFunctionCall(callID, comment string) error
|
||||
DenyFunctionCall(callID, reason string) error
|
||||
|
||||
// Type-safe approval methods for human contacts
|
||||
RespondToHumanContact(callID, response string) error
|
||||
// Type-safe approval methods
|
||||
ApproveToolCall(approvalID, comment string) error
|
||||
DenyToolCall(approvalID, reason string) error
|
||||
|
||||
// GetConversation fetches the conversation history for a session
|
||||
GetConversation(sessionID string) (*rpc.GetConversationResponse, error)
|
||||
|
||||
@@ -86,28 +86,10 @@ func New() (*Daemon, error) {
|
||||
return nil, fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
|
||||
// Create approval manager if API key is configured
|
||||
var approvalManager approval.Manager
|
||||
if cfg.APIKey != "" {
|
||||
slog.Info("creating approval manager", "api_base_url", cfg.APIBaseURL)
|
||||
approvalCfg := approval.Config{
|
||||
APIKey: cfg.APIKey,
|
||||
BaseURL: cfg.APIBaseURL,
|
||||
// Use defaults for now, could add to daemon config later
|
||||
}
|
||||
approvalManager, err = approval.NewManager(approvalCfg, eventBus, conversationStore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create approval manager: %w", err)
|
||||
}
|
||||
slog.Debug("approval manager created successfully")
|
||||
} else {
|
||||
slog.Warn("no API key configured, approval features disabled")
|
||||
}
|
||||
|
||||
// Set approval reconciler on session manager if approval manager exists
|
||||
if approvalManager != nil {
|
||||
sessionManager.SetApprovalReconciler(approvalManager)
|
||||
}
|
||||
// Always create local approval manager
|
||||
slog.Info("creating local approval manager")
|
||||
approvalManager := approval.NewManager(conversationStore, eventBus)
|
||||
slog.Debug("local approval manager created successfully")
|
||||
|
||||
return &Daemon{
|
||||
config: cfg,
|
||||
@@ -172,23 +154,10 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
sessionHandlers := rpc.NewSessionHandlers(d.sessions, d.store)
|
||||
sessionHandlers.Register(d.rpcServer)
|
||||
|
||||
// Always register approval handlers (even without API key)
|
||||
// Register local approval handlers
|
||||
approvalHandlers := rpc.NewApprovalHandlers(d.approvals, d.sessions)
|
||||
approvalHandlers.Register(d.rpcServer)
|
||||
|
||||
// Start approval polling if approval manager is available
|
||||
if d.approvals != nil {
|
||||
if err := d.approvals.Start(ctx); err != nil {
|
||||
_ = listener.Close()
|
||||
return fmt.Errorf("failed to start approval poller: %w", err)
|
||||
}
|
||||
defer d.approvals.Stop()
|
||||
|
||||
slog.Info("approval polling started")
|
||||
} else {
|
||||
slog.Warn("approval manager not configured (no API key)")
|
||||
}
|
||||
|
||||
slog.Info("daemon started", "socket", d.socketPath)
|
||||
|
||||
// Accept connections until context is cancelled
|
||||
|
||||
@@ -7,118 +7,22 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/config"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// mockAPIClient is a test implementation of the HumanLayer API client
|
||||
// This simulates a backend API for integration testing
|
||||
type mockAPIClient struct {
|
||||
functionCalls []humanlayer.FunctionCall // fixtures; GetPendingFunctionCalls returns those without a decision
|
||||
humanContacts []humanlayer.HumanContact // fixtures; GetPendingHumanContacts returns those without a decision
|
||||
decisions map[string]string // call_id -> decision
|
||||
mu sync.Mutex // locked by every method of the mock before it touches the fields above
|
||||
}
|
||||
|
||||
func newMockAPIClient() *mockAPIClient {
|
||||
return &mockAPIClient{
|
||||
decisions: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Return only function calls that haven't been decided
|
||||
var pending []humanlayer.FunctionCall
|
||||
for _, fc := range m.functionCalls {
|
||||
if _, decided := m.decisions[fc.CallID]; !decided {
|
||||
pending = append(pending, fc)
|
||||
}
|
||||
}
|
||||
return pending, nil
|
||||
}
|
||||
|
||||
func (m *mockAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Return only contacts that haven't been decided
|
||||
var pending []humanlayer.HumanContact
|
||||
for _, hc := range m.humanContacts {
|
||||
if _, decided := m.decisions[hc.CallID]; !decided {
|
||||
pending = append(pending, hc)
|
||||
}
|
||||
}
|
||||
return pending, nil
|
||||
}
|
||||
|
||||
func (m *mockAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.decisions[callID] = "approved"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.decisions[callID] = "denied"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.decisions[callID] = "responded"
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
// Create test socket path
|
||||
socketPath := testutil.SocketPath(t, "daemon-approval-test")
|
||||
|
||||
// Create mock API client with test data
|
||||
mockClient := newMockAPIClient()
|
||||
mockClient.functionCalls = []humanlayer.FunctionCall{
|
||||
{
|
||||
CallID: "fc-1",
|
||||
RunID: "test-run-1",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "dangerous_function",
|
||||
Kwargs: map[string]interface{}{
|
||||
"action": "delete_all",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
CallID: "fc-2",
|
||||
RunID: "test-run-2",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "safe_function",
|
||||
},
|
||||
},
|
||||
}
|
||||
mockClient.humanContacts = []humanlayer.HumanContact{
|
||||
{
|
||||
CallID: "hc-1",
|
||||
RunID: "test-run-1",
|
||||
Spec: humanlayer.HumanContactSpec{
|
||||
Msg: "Need human help",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a minimal in-memory store for testing
|
||||
testStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
@@ -126,16 +30,16 @@ func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Create real approval components for integration testing
|
||||
approvalStore := approval.NewMemoryStore()
|
||||
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, nil)
|
||||
// Create event bus for the approval manager
|
||||
eventBus := bus.NewEventBus()
|
||||
|
||||
// We need to manually construct the manager with our test client
|
||||
approvalManager := &approval.DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: approvalStore,
|
||||
Poller: poller,
|
||||
ConversationStore: testStore,
|
||||
// Create local approval manager
|
||||
approvalManager := approval.NewManager(testStore, eventBus)
|
||||
|
||||
// Create session manager
|
||||
sessionManager, err := session.NewManager(eventBus, testStore)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create test daemon with approval manager
|
||||
@@ -146,16 +50,11 @@ func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
},
|
||||
socketPath: socketPath,
|
||||
approvals: approvalManager,
|
||||
sessions: sessionManager,
|
||||
eventBus: eventBus,
|
||||
store: testStore,
|
||||
}
|
||||
|
||||
// Create session manager (we don't need real sessions for this test)
|
||||
|
||||
sessionManager, err := session.NewManager(nil, testStore)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create session manager: %v", err)
|
||||
}
|
||||
d.sessions = sessionManager
|
||||
|
||||
// Start daemon
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -178,82 +77,158 @@ func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
// Create RPC client
|
||||
client := &rpcClient{conn: conn}
|
||||
|
||||
// Wait for poller to fetch initial data
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Test scenario: Create local approvals and test RPC operations
|
||||
|
||||
// Test 1: Fetch all approvals
|
||||
// First, create test sessions and approvals in the database
|
||||
ctx2 := context.Background()
|
||||
|
||||
// Create test sessions
|
||||
session1 := &store.Session{
|
||||
ID: "test-session-1",
|
||||
RunID: "test-run-1",
|
||||
Query: "Test query 1",
|
||||
Status: store.SessionStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
if err := testStore.CreateSession(ctx2, session1); err != nil {
|
||||
t.Fatalf("failed to create session 1: %v", err)
|
||||
}
|
||||
|
||||
session2 := &store.Session{
|
||||
ID: "test-session-2",
|
||||
RunID: "test-run-2",
|
||||
Query: "Test query 2",
|
||||
Status: store.SessionStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
if err := testStore.CreateSession(ctx2, session2); err != nil {
|
||||
t.Fatalf("failed to create session 2: %v", err)
|
||||
}
|
||||
|
||||
// Create approvals via the RPC interface
|
||||
var createResp rpc.CreateApprovalResponse
|
||||
|
||||
// Approval 1
|
||||
err = client.call("createApproval", rpc.CreateApprovalRequest{
|
||||
RunID: "test-run-1",
|
||||
ToolName: "dangerous_function",
|
||||
ToolInput: json.RawMessage(`{"action": "delete_all"}`),
|
||||
}, &createResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create approval 1: %v", err)
|
||||
}
|
||||
approval1ID := createResp.ApprovalID
|
||||
|
||||
// Approval 2
|
||||
err = client.call("createApproval", rpc.CreateApprovalRequest{
|
||||
RunID: "test-run-2",
|
||||
ToolName: "safe_function",
|
||||
ToolInput: json.RawMessage(`{"action": "read_only"}`),
|
||||
}, &createResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create approval 2: %v", err)
|
||||
}
|
||||
approval2ID := createResp.ApprovalID
|
||||
|
||||
// Test 1: Fetch all approvals (should be empty without session filter)
|
||||
var fetchResp rpc.FetchApprovalsResponse
|
||||
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{}, &fetchResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch approvals: %v", err)
|
||||
}
|
||||
|
||||
if len(fetchResp.Approvals) != 3 {
|
||||
t.Errorf("expected 3 approvals, got %d", len(fetchResp.Approvals))
|
||||
if len(fetchResp.Approvals) != 0 {
|
||||
t.Errorf("expected 0 approvals without session filter, got %d", len(fetchResp.Approvals))
|
||||
}
|
||||
|
||||
// Test 2: Approve a function call
|
||||
// Test 2: Fetch approvals for session 1
|
||||
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{
|
||||
SessionID: "test-session-1",
|
||||
}, &fetchResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch approvals for session 1: %v", err)
|
||||
}
|
||||
|
||||
if len(fetchResp.Approvals) != 1 {
|
||||
t.Errorf("expected 1 approval for session 1, got %d", len(fetchResp.Approvals))
|
||||
}
|
||||
|
||||
// Test 3: Approve a function call
|
||||
var decisionResp rpc.SendDecisionResponse
|
||||
err = client.call("sendDecision", rpc.SendDecisionRequest{
|
||||
CallID: "fc-1",
|
||||
CallID: approval1ID,
|
||||
Type: "function_call",
|
||||
Decision: "approve",
|
||||
Comment: "Looks good",
|
||||
}, &decisionResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send decision: %v", err)
|
||||
t.Fatalf("failed to send approval decision: %v", err)
|
||||
}
|
||||
|
||||
if !decisionResp.Success {
|
||||
t.Errorf("expected success, got error: %s", decisionResp.Error)
|
||||
}
|
||||
|
||||
// Verify the decision was recorded
|
||||
if mockClient.decisions["fc-1"] != "approved" {
|
||||
t.Errorf("expected fc-1 to be approved, got %s", mockClient.decisions["fc-1"])
|
||||
}
|
||||
|
||||
// Test 3: Deny a function call
|
||||
// Test 4: Deny a function call
|
||||
err = client.call("sendDecision", rpc.SendDecisionRequest{
|
||||
CallID: "fc-2",
|
||||
CallID: approval2ID,
|
||||
Type: "function_call",
|
||||
Decision: "deny",
|
||||
Comment: "Too risky",
|
||||
}, &decisionResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send decision: %v", err)
|
||||
t.Fatalf("failed to send deny decision: %v", err)
|
||||
}
|
||||
|
||||
if !decisionResp.Success {
|
||||
t.Errorf("expected success, got error: %s", decisionResp.Error)
|
||||
}
|
||||
|
||||
// Test 4: Respond to human contact
|
||||
err = client.call("sendDecision", rpc.SendDecisionRequest{
|
||||
CallID: "hc-1",
|
||||
Type: "human_contact",
|
||||
Decision: "respond",
|
||||
Comment: "Here's the help you need",
|
||||
}, &decisionResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send decision: %v", err)
|
||||
}
|
||||
|
||||
if !decisionResp.Success {
|
||||
t.Errorf("expected success, got error: %s", decisionResp.Error)
|
||||
}
|
||||
|
||||
// Wait for next poll cycle
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Test 5: Verify approvals are no longer pending
|
||||
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{}, &fetchResp)
|
||||
err = client.call("fetchApprovals", rpc.FetchApprovalsRequest{
|
||||
SessionID: "test-session-1",
|
||||
}, &fetchResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch approvals: %v", err)
|
||||
}
|
||||
|
||||
if len(fetchResp.Approvals) != 0 {
|
||||
t.Errorf("expected 0 pending approvals after decisions, got %d", len(fetchResp.Approvals))
|
||||
t.Errorf("expected 0 pending approvals for session 1 after approval, got %d", len(fetchResp.Approvals))
|
||||
}
|
||||
|
||||
// Test 6: Try to approve non-existent approval
|
||||
err = client.call("sendDecision", rpc.SendDecisionRequest{
|
||||
CallID: "non-existent",
|
||||
Type: "function_call",
|
||||
Decision: "approve",
|
||||
Comment: "Should fail",
|
||||
}, &decisionResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send decision: %v", err)
|
||||
}
|
||||
|
||||
if decisionResp.Success {
|
||||
t.Error("expected failure for non-existent approval")
|
||||
}
|
||||
|
||||
// Test 7: Human contact is no longer supported
|
||||
err = client.call("sendDecision", rpc.SendDecisionRequest{
|
||||
CallID: "some-id",
|
||||
Type: "human_contact",
|
||||
Decision: "respond",
|
||||
Comment: "Should fail",
|
||||
}, &decisionResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send decision: %v", err)
|
||||
}
|
||||
|
||||
if decisionResp.Success {
|
||||
t.Error("expected failure for human contact type")
|
||||
}
|
||||
if decisionResp.Error != "human contact approvals are no longer supported" {
|
||||
t.Errorf("expected specific error for human contact, got: %s", decisionResp.Error)
|
||||
}
|
||||
|
||||
// Shutdown daemon
|
||||
|
||||
@@ -4,7 +4,8 @@ package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -14,57 +15,30 @@ import (
|
||||
"github.com/humanlayer/humanlayer/hld/client"
|
||||
"github.com/humanlayer/humanlayer/hld/config"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// TestSessionStateTransitionsIntegration tests the full flow of session state changes
|
||||
// when approvals are created and resolved
|
||||
func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
func TestDaemonSessionStateIntegration(t *testing.T) {
|
||||
// This test verifies that session state transitions work correctly with local approvals
|
||||
|
||||
// Create temporary socket path for test
|
||||
socketPath := testutil.SocketPath(t, "session-state")
|
||||
// Create test socket
|
||||
socketPath := testutil.SocketPath(t, "test")
|
||||
|
||||
// Use a temporary database file instead of :memory: to ensure all connections
|
||||
// access the same database (in-memory databases are unique per connection)
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "test.db")
|
||||
|
||||
// Create the store
|
||||
// Create test database
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
testStore, err := store.NewSQLiteStore(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test store: %v", err)
|
||||
t.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Set environment variables to ensure consistent test behavior
|
||||
t.Setenv("HUMANLAYER_DATABASE_PATH", dbPath)
|
||||
t.Setenv("HUMANLAYER_API_KEY", "test-key")
|
||||
|
||||
// Create event bus
|
||||
eventBus := bus.NewEventBus()
|
||||
|
||||
// Create mock API client
|
||||
mockClient := &mockSessionStateAPIClient{
|
||||
functionCalls: make(map[string]*humanlayer.FunctionCall),
|
||||
}
|
||||
|
||||
// Create real approval components
|
||||
approvalStore := approval.NewMemoryStore()
|
||||
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, eventBus)
|
||||
|
||||
// Create approval manager
|
||||
approvalManager := &approval.DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: approvalStore,
|
||||
Poller: poller,
|
||||
EventBus: eventBus,
|
||||
ConversationStore: testStore,
|
||||
}
|
||||
// Create local approval manager
|
||||
approvalManager := approval.NewManager(testStore, eventBus)
|
||||
|
||||
// Create session manager
|
||||
sessionManager, err := session.NewManager(eventBus, testStore)
|
||||
@@ -125,36 +99,43 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// 2. Add a tool call that needs approval
|
||||
toolCall := &store.ConversationEvent{
|
||||
SessionID: sessionID,
|
||||
// 2. Add a tool call event that would need approval
|
||||
toolCallEvent := &store.ConversationEvent{
|
||||
SessionID: sessionID, // Use the actual session ID from sessions table
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
Sequence: 1,
|
||||
EventType: store.EventTypeToolCall,
|
||||
ToolID: "tool-001",
|
||||
ToolName: "dangerous_function",
|
||||
ToolID: "tool-call-123",
|
||||
ToolInputJSON: `{"action": "delete_all"}`,
|
||||
ApprovalStatus: "", // No approval yet
|
||||
ApprovalID: "",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := testStore.AddConversationEvent(ctx, toolCall); err != nil {
|
||||
if err := testStore.AddConversationEvent(ctx, toolCallEvent); err != nil {
|
||||
t.Fatalf("failed to add tool call: %v", err)
|
||||
}
|
||||
|
||||
// 3. Simulate an approval coming from HumanLayer API
|
||||
approvalID := "approval-001"
|
||||
mockClient.AddFunctionCall(humanlayer.FunctionCall{
|
||||
CallID: approvalID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "dangerous_function",
|
||||
Kwargs: map[string]interface{}{
|
||||
"action": "delete_all",
|
||||
},
|
||||
},
|
||||
})
|
||||
// 3. Create an approval via RPC (simulating MCP creating approval)
|
||||
// We need to use the rpcClient directly for this test
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to daemon: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 4. Wait for poller to pick up the approval and correlate it
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
rpcClient := &rpcClient{conn: conn}
|
||||
var createResp rpc.CreateApprovalResponse
|
||||
if err := rpcClient.call("createApproval", rpc.CreateApprovalRequest{
|
||||
RunID: runID,
|
||||
ToolName: "dangerous_function",
|
||||
ToolInput: json.RawMessage(`{"action": "delete_all"}`),
|
||||
}, &createResp); err != nil {
|
||||
t.Fatalf("failed to create approval: %v", err)
|
||||
}
|
||||
approvalID := createResp.ApprovalID
|
||||
|
||||
// 4. Give time for event bus to propagate and correlation to happen
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 5. Check that session status changed to waiting_input
|
||||
updatedSession, err := testStore.GetSession(ctx, sessionID)
|
||||
@@ -165,7 +146,16 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
t.Errorf("expected session status to be waiting_input, got %s", updatedSession.Status)
|
||||
}
|
||||
|
||||
// 6. Check that approval was correlated
|
||||
// 6. Check that approval was created in database
|
||||
approval, err := testStore.GetApproval(ctx, approvalID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get approval: %v", err)
|
||||
}
|
||||
if approval.Status != store.ApprovalStatusLocalPending {
|
||||
t.Errorf("expected approval status to be pending, got %s", approval.Status)
|
||||
}
|
||||
|
||||
// 7. Check that approval was correlated with tool call
|
||||
conversation, err := testStore.GetConversation(ctx, claudeSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get conversation: %v", err)
|
||||
@@ -189,15 +179,15 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
t.Errorf("expected approval ID to be %s, got %s", approvalID, correlatedEvent.ApprovalID)
|
||||
}
|
||||
|
||||
// 7. Approve the function call via client
|
||||
if err := c.SendDecision(approvalID, "function_call", "approve", "Approved for testing"); err != nil {
|
||||
// 8. Approve the function call via client
|
||||
if err := c.SendDecision(approvalID, "approve", "Approved for testing"); err != nil {
|
||||
t.Fatalf("failed to send approval: %v", err)
|
||||
}
|
||||
|
||||
// Give time for status update
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 8. Check that session status changed back to running
|
||||
// 9. Check that session status changed back to running
|
||||
finalSession, err := testStore.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get final session: %v", err)
|
||||
@@ -206,7 +196,16 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
t.Errorf("expected session status to be running after approval, got %s", finalSession.Status)
|
||||
}
|
||||
|
||||
// 9. Check that approval status was updated
|
||||
// 10. Check that approval status was updated
|
||||
finalApproval, err := testStore.GetApproval(ctx, approvalID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get final approval: %v", err)
|
||||
}
|
||||
if finalApproval.Status != store.ApprovalStatusLocalApproved {
|
||||
t.Errorf("expected approval status to be approved, got %s", finalApproval.Status)
|
||||
}
|
||||
|
||||
// 11. Check that approval status was updated in conversation
|
||||
finalConversation, err := testStore.GetConversation(ctx, claudeSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get final conversation: %v", err)
|
||||
@@ -239,60 +238,3 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
t.Error("daemon did not shut down in time")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock API client for session state testing
|
||||
// NOTE(review): the map below has no lock; confirm it is never accessed concurrently by the test and a poller.
|
||||
type mockSessionStateAPIClient struct {
|
||||
functionCalls map[string]*humanlayer.FunctionCall // keyed by CallID (populated by AddFunctionCall)
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) AddFunctionCall(fc humanlayer.FunctionCall) {
|
||||
m.functionCalls[fc.CallID] = &fc
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
|
||||
var result []humanlayer.FunctionCall
|
||||
for _, fc := range m.functionCalls {
|
||||
if fc.Status == nil || fc.Status.RespondedAt == nil {
|
||||
result = append(result, *fc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
|
||||
return []humanlayer.HumanContact{}, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
|
||||
if fc, ok := m.functionCalls[callID]; ok {
|
||||
if fc.Status == nil {
|
||||
fc.Status = &humanlayer.FunctionCallStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
fc.Status.RespondedAt = &now
|
||||
approved := true
|
||||
fc.Status.Approved = &approved
|
||||
fc.Status.Comment = comment
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
|
||||
if fc, ok := m.functionCalls[callID]; ok {
|
||||
if fc.Status == nil {
|
||||
fc.Status = &humanlayer.FunctionCallStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
fc.Status.RespondedAt = &now
|
||||
approved := false
|
||||
fc.Status.Approved = &approved
|
||||
fc.Status.Comment = reason
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -7,15 +7,16 @@ import (
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
// ApprovalHandlers provides RPC handlers for approval management
|
||||
// ApprovalHandlers provides RPC handlers for local approval management
|
||||
type ApprovalHandlers struct {
|
||||
approvals approval.Manager // approval manager the handlers delegate to (HandleCreateApproval calls CreateApproval on it)
|
||||
sessions session.SessionManager // session manager; HandleFetchApprovals calls GetSessionInfo on it
|
||||
}
|
||||
|
||||
// NewApprovalHandlers creates new approval RPC handlers
|
||||
// NewApprovalHandlers creates new local approval RPC handlers
|
||||
func NewApprovalHandlers(approvals approval.Manager, sessions session.SessionManager) *ApprovalHandlers {
|
||||
return &ApprovalHandlers{
|
||||
approvals: approvals,
|
||||
@@ -23,6 +24,47 @@ func NewApprovalHandlers(approvals approval.Manager, sessions session.SessionMan
|
||||
}
|
||||
}
|
||||
|
||||
// CreateApprovalRequest is the request for creating a local approval
|
||||
type CreateApprovalRequest struct {
|
||||
RunID string `json:"run_id"` // required; HandleCreateApproval rejects an empty value
|
||||
ToolName string `json:"tool_name"` // required; HandleCreateApproval rejects an empty value
|
||||
ToolInput json.RawMessage `json:"tool_input"` // required; kept as raw JSON and passed through to the approval manager undecoded
|
||||
}
|
||||
|
||||
// CreateApprovalResponse is the response for creating a local approval
|
||||
type CreateApprovalResponse struct {
|
||||
ApprovalID string `json:"approval_id"` // ID returned by the approval manager's CreateApproval call
|
||||
}
|
||||
|
||||
// HandleCreateApproval handles the CreateApproval RPC method
|
||||
func (h *ApprovalHandlers) HandleCreateApproval(ctx context.Context, params json.RawMessage) (interface{}, error) {
|
||||
var req CreateApprovalRequest
|
||||
if err := json.Unmarshal(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("invalid request: %w", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.RunID == "" {
|
||||
return nil, fmt.Errorf("run_id is required")
|
||||
}
|
||||
if req.ToolName == "" {
|
||||
return nil, fmt.Errorf("tool_name is required")
|
||||
}
|
||||
if req.ToolInput == nil {
|
||||
return nil, fmt.Errorf("tool_input is required")
|
||||
}
|
||||
|
||||
// Create the approval
|
||||
approvalID, err := h.approvals.CreateApproval(ctx, req.RunID, req.ToolName, req.ToolInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create approval: %w", err)
|
||||
}
|
||||
|
||||
return &CreateApprovalResponse{
|
||||
ApprovalID: approvalID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetchApprovalsRequest is the request for fetching approvals
|
||||
type FetchApprovalsRequest struct {
|
||||
SessionID string `json:"session_id,omitempty"` // Optional filter by session
|
||||
@@ -30,16 +72,11 @@ type FetchApprovalsRequest struct {
|
||||
|
||||
// FetchApprovalsResponse is the response for fetching approvals
|
||||
type FetchApprovalsResponse struct {
|
||||
Approvals []approval.PendingApproval `json:"approvals"`
|
||||
Approvals []*store.Approval `json:"approvals"`
|
||||
}
|
||||
|
||||
// HandleFetchApprovals handles the FetchApprovals RPC method
|
||||
func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json.RawMessage) (interface{}, error) {
|
||||
// Check if approval manager is configured
|
||||
if h.approvals == nil {
|
||||
return nil, fmt.Errorf("approval features not available: no API key configured")
|
||||
}
|
||||
|
||||
var req FetchApprovalsRequest
|
||||
if params != nil {
|
||||
if err := json.Unmarshal(params, &req); err != nil {
|
||||
@@ -47,27 +84,17 @@ func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json
|
||||
}
|
||||
}
|
||||
|
||||
var approvals []approval.PendingApproval
|
||||
var err error
|
||||
// If no session ID provided, return empty list
|
||||
if req.SessionID == "" {
|
||||
return &FetchApprovalsResponse{
|
||||
Approvals: []*store.Approval{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.SessionID != "" {
|
||||
// Get the session to find its run_id
|
||||
sessionInfo, err := h.sessions.GetSessionInfo(req.SessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session not found: %w", err)
|
||||
}
|
||||
|
||||
// Get approvals by run_id
|
||||
approvals, err = h.approvals.GetPendingApprovalsByRunID(sessionInfo.RunID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Get all pending approvals
|
||||
approvals, err = h.approvals.GetPendingApprovals("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
|
||||
}
|
||||
// Get approvals for the session
|
||||
approvals, err := h.approvals.GetPendingApprovals(ctx, req.SessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch approvals: %w", err)
|
||||
}
|
||||
|
||||
return &FetchApprovalsResponse{
|
||||
@@ -77,9 +104,9 @@ func (h *ApprovalHandlers) HandleFetchApprovals(ctx context.Context, params json
|
||||
|
||||
// SendDecisionRequest is the request for sending a decision
|
||||
type SendDecisionRequest struct {
|
||||
CallID string `json:"call_id"`
|
||||
Type string `json:"type"` // "function_call" or "human_contact"
|
||||
Decision string `json:"decision"` // "approve", "deny", or "respond"
|
||||
CallID string `json:"call_id"` // Actually approval ID, but keeping name for compatibility
|
||||
Type string `json:"type"` // Ignored for local approvals
|
||||
Decision string `json:"decision"` // "approve" or "deny"
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
@@ -91,11 +118,6 @@ type SendDecisionResponse struct {
|
||||
|
||||
// HandleSendDecision handles the SendDecision RPC method
|
||||
func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.RawMessage) (interface{}, error) {
|
||||
// Check if approval manager is configured
|
||||
if h.approvals == nil {
|
||||
return nil, fmt.Errorf("approval features not available: no API key configured")
|
||||
}
|
||||
|
||||
var req SendDecisionRequest
|
||||
if err := json.Unmarshal(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("invalid request: %w", err)
|
||||
@@ -105,38 +127,30 @@ func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.R
|
||||
if req.CallID == "" {
|
||||
return nil, fmt.Errorf("call_id is required")
|
||||
}
|
||||
if req.Type == "" {
|
||||
return nil, fmt.Errorf("type is required")
|
||||
}
|
||||
if req.Decision == "" {
|
||||
return nil, fmt.Errorf("decision is required")
|
||||
}
|
||||
|
||||
// Check if this is a human contact type (no longer supported)
|
||||
if req.Type == "human_contact" {
|
||||
return &SendDecisionResponse{
|
||||
Success: false,
|
||||
Error: "human contact approvals are no longer supported",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
switch req.Type {
|
||||
case "function_call":
|
||||
switch req.Decision {
|
||||
case "approve":
|
||||
err = h.approvals.ApproveFunctionCall(ctx, req.CallID, req.Comment)
|
||||
case "deny":
|
||||
if req.Comment == "" {
|
||||
return nil, fmt.Errorf("comment is required for denial")
|
||||
}
|
||||
err = h.approvals.DenyFunctionCall(ctx, req.CallID, req.Comment)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid decision for function_call: %s", req.Decision)
|
||||
}
|
||||
case "human_contact":
|
||||
if req.Decision != "respond" {
|
||||
return nil, fmt.Errorf("invalid decision for human_contact: %s", req.Decision)
|
||||
}
|
||||
switch req.Decision {
|
||||
case "approve":
|
||||
err = h.approvals.ApproveToolCall(ctx, req.CallID, req.Comment)
|
||||
case "deny":
|
||||
if req.Comment == "" {
|
||||
return nil, fmt.Errorf("comment is required for human contact response")
|
||||
return nil, fmt.Errorf("comment is required for denial")
|
||||
}
|
||||
err = h.approvals.RespondToHumanContact(ctx, req.CallID, req.Comment)
|
||||
err = h.approvals.DenyToolCall(ctx, req.CallID, req.Comment)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid type: %s", req.Type)
|
||||
return nil, fmt.Errorf("invalid decision: %s (must be 'approve' or 'deny')", req.Decision)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -151,8 +165,43 @@ func (h *ApprovalHandlers) HandleSendDecision(ctx context.Context, params json.R
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register registers all approval handlers with the RPC server
|
||||
// GetApprovalRequest carries the parameters of the getApproval RPC call.
type GetApprovalRequest struct {
	// ApprovalID is the ID of the approval to look up (JSON key "approval_id").
	ApprovalID string `json:"approval_id"`
}
|
||||
|
||||
// GetApprovalResponse is the response for getting a specific approval
|
||||
type GetApprovalResponse struct {
|
||||
Approval *store.Approval `json:"approval"`
|
||||
}
|
||||
|
||||
// HandleGetApproval handles the GetApproval RPC method
|
||||
func (h *ApprovalHandlers) HandleGetApproval(ctx context.Context, params json.RawMessage) (interface{}, error) {
|
||||
var req GetApprovalRequest
|
||||
if err := json.Unmarshal(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("invalid request: %w", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ApprovalID == "" {
|
||||
return nil, fmt.Errorf("approval_id is required")
|
||||
}
|
||||
|
||||
// Get the approval
|
||||
approval, err := h.approvals.GetApproval(ctx, req.ApprovalID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get approval: %w", err)
|
||||
}
|
||||
|
||||
return &GetApprovalResponse{
|
||||
Approval: approval,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register registers all local approval handlers with the RPC server
|
||||
func (h *ApprovalHandlers) Register(server *Server) {
|
||||
server.Register("createApproval", h.HandleCreateApproval)
|
||||
server.Register("fetchApprovals", h.HandleFetchApprovals)
|
||||
server.Register("getApproval", h.HandleGetApproval)
|
||||
server.Register("sendDecision", h.HandleSendDecision)
|
||||
}
|
||||
|
||||
@@ -173,6 +173,28 @@ func (s *SQLiteStore) initSchema() error {
|
||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
description TEXT
|
||||
);
|
||||
|
||||
-- Approvals table for local approvals
|
||||
CREATE TABLE IF NOT EXISTS approvals (
|
||||
id TEXT PRIMARY KEY,
|
||||
run_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL CHECK (status IN ('pending', 'approved', 'denied')),
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
responded_at DATETIME,
|
||||
|
||||
-- Tool approval fields
|
||||
tool_name TEXT NOT NULL,
|
||||
tool_input TEXT NOT NULL, -- JSON
|
||||
|
||||
-- Response fields
|
||||
comment TEXT, -- For denial reasons or approval notes
|
||||
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_pending ON approvals(status) WHERE status = 'pending';
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_session ON approvals(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_run_id ON approvals(run_id);
|
||||
`
|
||||
|
||||
if _, err := s.db.Exec(schema); err != nil {
|
||||
@@ -258,6 +280,49 @@ func (s *SQLiteStore) applyMigrations() error {
|
||||
slog.Info("Migration 3 applied successfully")
|
||||
}
|
||||
|
||||
// Migration 4: Add approvals table for local approvals
|
||||
if currentVersion < 4 {
|
||||
slog.Info("Applying migration 4: Add approvals table for local approvals")
|
||||
|
||||
// Create the approvals table
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS approvals (
|
||||
id TEXT PRIMARY KEY,
|
||||
run_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL CHECK (status IN ('pending', 'approved', 'denied')),
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
responded_at DATETIME,
|
||||
|
||||
-- Tool approval fields
|
||||
tool_name TEXT NOT NULL,
|
||||
tool_input TEXT NOT NULL, -- JSON
|
||||
|
||||
-- Response fields
|
||||
comment TEXT, -- For denial reasons or approval notes
|
||||
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_pending ON approvals(status) WHERE status = 'pending';
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_session ON approvals(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_run_id ON approvals(run_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create approvals table: %w", err)
|
||||
}
|
||||
|
||||
// Record migration
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO schema_version (version, description)
|
||||
VALUES (4, 'Add approvals table for local approvals')
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration 4: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("Migration 4 applied successfully")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1104,6 +1169,154 @@ func (s *SQLiteStore) StoreRawEvent(ctx context.Context, sessionID string, event
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateApproval creates a new approval
|
||||
func (s *SQLiteStore) CreateApproval(ctx context.Context, approval *Approval) error {
|
||||
// Validate status
|
||||
if !approval.Status.IsValid() {
|
||||
return fmt.Errorf("invalid approval status: %s", approval.Status)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO approvals (
|
||||
id, run_id, session_id, status, created_at,
|
||||
tool_name, tool_input
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query,
|
||||
approval.ID, approval.RunID, approval.SessionID, approval.Status.String(), approval.CreatedAt,
|
||||
approval.ToolName, string(approval.ToolInput),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create approval: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetApproval retrieves an approval by ID
|
||||
func (s *SQLiteStore) GetApproval(ctx context.Context, id string) (*Approval, error) {
|
||||
query := `
|
||||
SELECT id, run_id, session_id, status, created_at, responded_at,
|
||||
tool_name, tool_input, comment
|
||||
FROM approvals WHERE id = ?
|
||||
`
|
||||
|
||||
var approval Approval
|
||||
var respondedAt sql.NullTime
|
||||
var comment sql.NullString
|
||||
var statusStr string
|
||||
var toolInputStr string
|
||||
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
|
||||
&approval.CreatedAt, &respondedAt,
|
||||
&approval.ToolName, &toolInputStr, &comment,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("approval not found: %s", id)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get approval: %w", err)
|
||||
}
|
||||
|
||||
// Convert status string to ApprovalStatus
|
||||
approval.Status = ApprovalStatus(statusStr)
|
||||
if !approval.Status.IsValid() {
|
||||
return nil, fmt.Errorf("invalid approval status in database: %s", statusStr)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if respondedAt.Valid {
|
||||
approval.RespondedAt = &respondedAt.Time
|
||||
}
|
||||
approval.Comment = comment.String
|
||||
approval.ToolInput = json.RawMessage(toolInputStr)
|
||||
|
||||
return &approval, nil
|
||||
}
|
||||
|
||||
// GetPendingApprovals retrieves all pending approvals for a session
|
||||
func (s *SQLiteStore) GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error) {
|
||||
query := `
|
||||
SELECT id, run_id, session_id, status, created_at, responded_at,
|
||||
tool_name, tool_input, comment
|
||||
FROM approvals
|
||||
WHERE session_id = ? AND status = ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, sessionID, ApprovalStatusLocalPending.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get pending approvals: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var approvals []*Approval
|
||||
for rows.Next() {
|
||||
var approval Approval
|
||||
var respondedAt sql.NullTime
|
||||
var comment sql.NullString
|
||||
var statusStr string
|
||||
var toolInputStr string
|
||||
|
||||
err := rows.Scan(
|
||||
&approval.ID, &approval.RunID, &approval.SessionID, &statusStr,
|
||||
&approval.CreatedAt, &respondedAt,
|
||||
&approval.ToolName, &toolInputStr, &comment,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan approval: %w", err)
|
||||
}
|
||||
|
||||
// Convert status string to ApprovalStatus
|
||||
approval.Status = ApprovalStatus(statusStr)
|
||||
if !approval.Status.IsValid() {
|
||||
return nil, fmt.Errorf("invalid approval status in database: %s", statusStr)
|
||||
}
|
||||
|
||||
// Handle nullable fields
|
||||
if respondedAt.Valid {
|
||||
approval.RespondedAt = &respondedAt.Time
|
||||
}
|
||||
approval.Comment = comment.String
|
||||
approval.ToolInput = json.RawMessage(toolInputStr)
|
||||
|
||||
approvals = append(approvals, &approval)
|
||||
}
|
||||
|
||||
return approvals, nil
|
||||
}
|
||||
|
||||
// UpdateApprovalResponse updates the status and comment of an approval
|
||||
func (s *SQLiteStore) UpdateApprovalResponse(ctx context.Context, id string, status ApprovalStatus, comment string) error {
|
||||
// Validate status
|
||||
if !status.IsValid() {
|
||||
return fmt.Errorf("invalid approval status: %s", status)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE approvals
|
||||
SET status = ?, comment = ?, responded_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
result, err := s.db.ExecContext(ctx, query, status.String(), comment, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update approval response: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("approval not found: %s", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to convert MCP config to store format
|
||||
func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServer) ([]MCPServer, error) {
|
||||
// First, collect all server names and sort them for deterministic ordering
|
||||
|
||||
@@ -24,6 +24,7 @@ type ConversationStore interface {
|
||||
|
||||
// Tool call operations
|
||||
GetPendingToolCall(ctx context.Context, sessionID string, toolName string) (*ConversationEvent, error)
|
||||
GetUncorrelatedPendingToolCall(ctx context.Context, sessionID string, toolName string) (*ConversationEvent, error)
|
||||
GetPendingToolCalls(ctx context.Context, sessionID string) ([]*ConversationEvent, error)
|
||||
MarkToolCallCompleted(ctx context.Context, toolID string, sessionID string) error
|
||||
CorrelateApproval(ctx context.Context, sessionID string, toolName string, approvalID string) error
|
||||
@@ -37,6 +38,12 @@ type ConversationStore interface {
|
||||
// Raw event storage (for debugging)
|
||||
StoreRawEvent(ctx context.Context, sessionID string, eventJSON string) error
|
||||
|
||||
// Approval operations for local approvals
|
||||
CreateApproval(ctx context.Context, approval *Approval) error
|
||||
GetApproval(ctx context.Context, id string) (*Approval, error)
|
||||
GetPendingApprovals(ctx context.Context, sessionID string) ([]*Approval, error)
|
||||
UpdateApprovalResponse(ctx context.Context, id string, status ApprovalStatus, comment string) error
|
||||
|
||||
// Database lifecycle
|
||||
Close() error
|
||||
}
|
||||
@@ -123,6 +130,44 @@ type MCPServer struct {
|
||||
EnvJSON string // JSON object
|
||||
}
|
||||
|
||||
// ApprovalStatus is the lifecycle state of a local approval, stored as text.
type ApprovalStatus string

// The complete set of statuses accepted by IsValid.
const (
	ApprovalStatusLocalPending  ApprovalStatus = "pending"
	ApprovalStatusLocalApproved ApprovalStatus = "approved"
	ApprovalStatusLocalDenied   ApprovalStatus = "denied"
)

// String returns the status as its underlying string value.
func (s ApprovalStatus) String() string { return string(s) }

// IsValid reports whether s is one of pending, approved or denied.
func (s ApprovalStatus) IsValid() bool {
	return s == ApprovalStatusLocalPending ||
		s == ApprovalStatusLocalApproved ||
		s == ApprovalStatusLocalDenied
}
|
||||
|
||||
// Approval represents a local approval request
|
||||
type Approval struct {
|
||||
ID string `json:"id"`
|
||||
RunID string `json:"run_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Status ApprovalStatus `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RespondedAt *time.Time `json:"responded_at,omitempty"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolInput json.RawMessage `json:"tool_input"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// EventType constants
|
||||
const (
|
||||
EventTypeMessage = "message"
|
||||
|
||||
22
hlyr/.npmignore
Normal file
22
hlyr/.npmignore
Normal file
@@ -0,0 +1,22 @@
|
||||
# Test files and directories
|
||||
hack/
|
||||
test-local-approvals.ts
|
||||
test-*.ts
|
||||
test_local_approvals.md
|
||||
|
||||
# Test configs
|
||||
mcp-config.json
|
||||
test-mcp-logging.sh
|
||||
|
||||
# Source files (only dist is published)
|
||||
src/
|
||||
tsconfig.json
|
||||
.eslintrc.json
|
||||
.prettierrc
|
||||
|
||||
# Development files
|
||||
*.log
|
||||
.DS_Store
|
||||
node_modules/
|
||||
coverage/
|
||||
.vscode/
|
||||
@@ -245,6 +245,10 @@ npm test
|
||||
npm run dev
|
||||
```
|
||||
|
||||
### Testing Local Approvals
|
||||
|
||||
For testing the local MCP approvals system without HumanLayer API access, see [test_local_approvals.md](./test_local_approvals.md).
|
||||
|
||||
## License
|
||||
|
||||
Apache-2.0
|
||||
|
||||
465
hlyr/hack/test-local-approvals.ts
Executable file
465
hlyr/hack/test-local-approvals.ts
Executable file
@@ -0,0 +1,465 @@
|
||||
#!/usr/bin/env bun
|
||||
|
||||
import { connectWithRetry, DaemonClient, Approval } from '../src/daemonClient.js'
|
||||
import { homedir } from 'os'
|
||||
import { join } from 'path'
|
||||
import * as fs from 'fs/promises'
|
||||
import { spawn } from 'child_process'
|
||||
import { Database } from 'bun:sqlite'
|
||||
import { parseArgs } from 'util'
|
||||
|
||||
// Configuration: everything lives under ~/.humanlayer
const HUMANLAYER_DIR = join(homedir(), '.humanlayer')
const SOCKET_PATH = join(HUMANLAYER_DIR, 'daemon.sock')
const DB_PATH = join(HUMANLAYER_DIR, 'daemon.db')
const MCP_LOG_DIR = join(HUMANLAYER_DIR, 'logs')

// ANSI escape sequences used to colorize console output; `reset` restores
// the terminal's default color.
const colors = {
  reset: '\x1b[0m',
  green: '\x1b[32m',
  red: '\x1b[31m',
  yellow: '\x1b[33m',
  blue: '\x1b[34m',
  cyan: '\x1b[36m',
  magenta: '\x1b[35m',
}
|
||||
|
||||
// Print one colorized line of the form "[ISO timestamp] [LEVEL] message".
function log(level: 'info' | 'success' | 'error' | 'debug' | 'mcp', message: string) {
  // Color used for each log level
  const palette: Record<typeof level, string> = {
    info: colors.blue,
    success: colors.green,
    error: colors.red,
    debug: colors.cyan,
    mcp: colors.magenta,
  }

  const stamp = new Date().toISOString()
  const tag = level.toUpperCase()
  console.log(`${palette[level]}[${stamp}] [${tag}] ${message}${colors.reset}`)
}
|
||||
|
||||
// Follow the MCP log file of the given run and echo every entry through log().
//
// Waits up to 10 x 500ms for the file to appear, then spawns `tail -f` on it.
// Lines that parse as JSON are printed as "[level] message <data>"; anything
// else is printed verbatim. Returns a function that kills the tail process.
async function monitorMCPLogs(runId: string): Promise<() => void> {
  const logPath = join(MCP_LOG_DIR, `mcp-claude-approvals-${runId}.log`)
  log('info', `Monitoring MCP logs at: ${logPath}`)

  // Wait for log file to exist
  let attempts = 0
  while (attempts < 10) {
    try {
      await fs.access(logPath)
      break
    } catch {
      await new Promise(resolve => setTimeout(resolve, 500))
      attempts++
    }
  }

  // Create a tail-like process to follow the log
  const tail = spawn('tail', ['-f', logPath])

  // Print a single complete log line
  const emitLine = (line: string) => {
    try {
      const entry = JSON.parse(line)
      const prefix = `[${entry.level}] ${entry.message}`
      const dataStr = entry.data ? ` ${JSON.stringify(entry.data)}` : ''
      log('mcp', `${prefix}${dataStr}`)
    } catch {
      // Not JSON, just print as-is
      log('mcp', line)
    }
  }

  // stdout chunks are not aligned to line boundaries: keep the trailing
  // partial line and prepend it to the next chunk so a JSON entry split
  // across two chunks is still parsed as one line.
  let pending = ''
  tail.stdout.on('data', data => {
    pending += data.toString()
    const lines = pending.split('\n')
    pending = lines.pop() || ''
    for (const line of lines) {
      if (line) {
        emitLine(line)
      }
    }
  })

  // Flush whatever is left when tail's output ends without a final newline
  tail.stdout.on('end', () => {
    if (pending) {
      emitLine(pending)
      pending = ''
    }
  })

  tail.stderr.on('data', data => {
    log('error', `MCP log error: ${data}`)
  })

  // Without a listener, a failed spawn (e.g. no `tail` binary) would raise
  // an unhandled 'error' event and crash the script.
  tail.on('error', error => {
    log('error', `MCP log error: ${error}`)
  })

  // Return cleanup function
  return () => {
    tail.kill()
  }
}
|
||||
|
||||
// Start a session whose prompt asks for a file write to `testFile`, with the
// approvals MCP server configured as the permission prompt tool.
// Returns the session object produced by client.launchSession.
async function launchApprovalSession(client: DaemonClient, testFile: string) {
  log('info', 'Launching session that will trigger file write approval...')

  // MCP server definition: the approvals server is run via npm with debug
  // logging switched on.
  const approvalsServer = {
    command: 'npm',
    args: ['run', 'dev', 'mcp', 'claude_approvals'],
    env: {
      ...process.env,
      HUMANLAYER_MCP_DEBUG: 'true',
    },
  }

  const request = {
    query: `Please write the text "Hello from MCP approval test!" to a file named "${testFile}". This is a test of the approval system.`,
    working_dir: process.cwd(),
    mcp_config: {
      mcpServers: {
        approvals: approvalsServer,
      },
    },
    permission_prompt_tool: 'mcp__approvals__request_permission',
  }

  const launched = await client.launchSession(request)
  log('success', `Session launched: ${launched.session_id}`)
  log('info', `Run ID: ${launched.run_id}`)
  return launched
}
|
||||
|
||||
// Start a session for interactive monitoring. When no query is supplied, a
// default file-write prompt with a timestamp and random ID is used so the
// content differs on every run. Returns the session object produced by
// client.launchSession.
async function launchInteractiveSession(client: DaemonClient, query?: string) {
  // Generate random content to ensure approval is always triggered
  const timestamp = new Date().toISOString()
  const randomId = Math.random().toString(36).substring(7)
  const fallbackQuery = `Please write "Hello from HumanLayer MCP test!\nTimestamp: ${timestamp}\nTest ID: ${randomId}" to a file named blah.txt`
  const effectiveQuery = query ? query : fallbackQuery

  log('info', 'Launching interactive session...')
  log('info', `Query: ${effectiveQuery}`)

  // MCP server definition: the approvals server is run via npm with debug
  // logging switched on.
  const approvalsServer = {
    command: 'npm',
    args: ['run', 'dev', 'mcp', 'claude_approvals'],
    env: {
      ...process.env,
      HUMANLAYER_MCP_DEBUG: 'true',
    },
  }

  const request = {
    query: effectiveQuery,
    working_dir: process.cwd(),
    mcp_config: {
      mcpServers: {
        approvals: approvalsServer,
      },
    },
    permission_prompt_tool: 'mcp__approvals__request_permission',
  }

  const launched = await client.launchSession(request)
  log('success', `Session launched: ${launched.session_id}`)
  log('info', `Run ID: ${launched.run_id}`)
  return launched
}
|
||||
|
||||
// Automated test mode.
//
// Launches a session that asks for a file write, waits up to 30 seconds for
// a pending approval to appear in the daemon database, approves it, checks
// that the file was written and summarizes the MCP log. The process exits
// with code 0 when the approval flow succeeded and 1 when any step failed or
// an exception was thrown.
async function runAutomatedTest() {
  log('info', '=== Automated MCP Approval Test ===\n')

  // Enable MCP debug logging
  process.env.HUMANLAYER_MCP_DEBUG = 'true'

  // Connect to daemon
  const client = await connectWithRetry(SOCKET_PATH, 3, 1000)
  log('success', 'Connected to daemon')

  // Tracks whether any step failed so the exit code reflects the outcome
  let failed = false
  let stopMonitoring: (() => void) | undefined

  try {
    // Generate test file name
    const testFile = `test-mcp-approval-${Date.now()}.txt`
    log('info', `Test file: ${testFile}`)

    // Launch session
    const session = await launchApprovalSession(client, testFile)

    // Start monitoring MCP logs
    stopMonitoring = await monitorMCPLogs(session.run_id)

    // Subscribe to approval events
    const eventEmitter = await client.subscribe({
      event_types: ['new_approval', 'approval_resolved'],
      session_id: session.session_id,
    })

    let approvalReceived = false
    let approvalId: string | null = null

    eventEmitter.on('event', event => {
      if (event.type === 'new_approval') {
        log('success', `New approval event received!`)
        log('debug', `Event data: ${JSON.stringify(event.data)}`)
        approvalReceived = true
      } else if (event.type === 'approval_resolved') {
        log('success', `Approval resolved: ${JSON.stringify(event.data)}`)
      }
    })

    // Wait for approval to be created
    log('info', 'Waiting for Claude to request file write approval...')
    const maxWait = 30000 // 30 seconds
    const startTime = Date.now()

    while (!approvalId && Date.now() - startTime < maxWait) {
      // Check database for approvals using Bun's native SQLite. The handle
      // is closed right after the query, even if the query throws.
      const db = new Database(DB_PATH, { readonly: true })
      let approvals: Approval[]
      try {
        // The status is bound as a parameter: a double-quoted "pending" is
        // an identifier in standard SQL and only works through a SQLite
        // compatibility quirk.
        const stmt = db.prepare('SELECT * FROM approvals WHERE session_id = ? AND status = ?')
        approvals = stmt.all(session.session_id, 'pending') as Approval[]
      } finally {
        db.close()
      }

      if (approvals.length === 0) {
        await new Promise(resolve => setTimeout(resolve, 1000))
        continue
      }

      const approval = approvals[0]
      approvalId = approval.id
      log('success', `Approval found in database: ${approvalId}`)
      log('debug', `Tool: ${approval.tool_name}`)
      log('debug', `Input: ${approval.tool_input}`)

      // Auto-approve after a short delay
      log('info', 'Auto-approving in 2 seconds...')
      await new Promise(resolve => setTimeout(resolve, 2000))

      try {
        await client.sendDecision(approvalId, 'approve', 'Automated test approval')
        log('success', '✓ Approval sent successfully')

        // Wait for file to be created
        await new Promise(resolve => setTimeout(resolve, 3000))

        // Check if file was created
        try {
          await fs.access(testFile)
          log('success', `✓ File "${testFile}" was created successfully`)

          // Read and display content
          const content = await fs.readFile(testFile, 'utf-8')
          log('info', `File content: ${content}`)

          // Clean up test file
          await fs.unlink(testFile)
          log('info', 'Test file cleaned up')
        } catch {
          failed = true
          log('error', 'File was not created - approval may have failed')
        }
      } catch (error) {
        failed = true
        log('error', `Failed to approve: ${error}`)
      }
    }

    if (!approvalId) {
      failed = true
      log('error', 'No approval was requested within timeout period')
    }

    // Analyze MCP logs
    const mcpLogPath = join(MCP_LOG_DIR, `mcp-claude-approvals-${session.run_id}.log`)
    try {
      const logs = await fs.readFile(mcpLogPath, 'utf-8')
      const logLines = logs.split('\n').filter(Boolean)

      log('info', `\nMCP Log Summary:`)
      log('info', `Total entries: ${logLines.length}`)

      const errors = logLines.filter(l => l.includes('"level":"ERROR"'))
      if (errors.length > 0) {
        log('error', `Found ${errors.length} errors in MCP logs`)
      } else {
        log('success', '✓ No errors in MCP logs')
      }
    } catch (error) {
      log('error', `Could not analyze MCP logs: ${error}`)
    }

    // Summary
    log('info', '\n=== Test Summary ===')
    log('success', '✓ Session launched with MCP approvals')
    log('success', '✓ MCP logs monitored successfully')

    if (approvalReceived) {
      log('success', '✓ Approval event received via subscription')
    }

    if (approvalId) {
      log('success', '✓ Approval created and processed')
    }
  } catch (error) {
    // process.exit below would otherwise discard the exception silently
    failed = true
    log('error', `Error: ${error}`)
  } finally {
    // Stop the tail process on every path so it is not left running
    if (stopMonitoring) {
      stopMonitoring()
    }
    client.close()
    // Exit explicitly after the test completes, reporting failure to the caller
    process.exit(failed ? 1 : 0)
  }
}
|
||||
|
||||
// Interactive monitoring mode.
//
// Launches a session (with the optional user query), follows its MCP log and
// prints approval and status events until the session reports "completed" or
// "failed", or the user presses Ctrl+C. The process exits with code 0 after
// a normal stop and 1 when an exception was thrown.
async function runInteractiveMode(query?: string) {
  log('info', '=== Interactive MCP Monitoring Mode ===\n')

  // Enable MCP debug logging
  process.env.HUMANLAYER_MCP_DEBUG = 'true'

  // Connect to daemon
  const client = await connectWithRetry(SOCKET_PATH, 3, 1000)
  log('success', 'Connected to daemon')

  let exitCode = 0
  let stopMonitoring: (() => void) | undefined

  try {
    // Launch session
    const session = await launchInteractiveSession(client, query)

    // Start monitoring MCP logs
    stopMonitoring = await monitorMCPLogs(session.run_id)

    // Subscribe to approval events
    const eventEmitter = await client.subscribe({
      event_types: ['new_approval', 'approval_resolved', 'session_status_changed'],
      session_id: session.session_id,
    })

    // Resolved by the event handler once the session completes or fails.
    // Created before the handler is attached so no status event is missed.
    let markSessionDone: () => void = () => {}
    const sessionDone = new Promise<void>(resolve => {
      markSessionDone = resolve
    })

    eventEmitter.on('event', event => {
      if (event.type === 'new_approval') {
        log('success', '\n🔔 NEW APPROVAL REQUEST!')
        log('info', `Approval ID: ${event.data.approval_id}`)
        log('info', `Tool: ${event.data.tool_name || 'N/A'}`)
        log('info', '\nYou can approve/deny this in:')
        log('info', '  - TUI: npx humanlayer tui')
        log('info', '  - WUI: Open the desktop app')
        log('info', `  - Session URL: #/sessions/${session.session_id}\n`)
      } else if (event.type === 'approval_resolved') {
        const approved = event.data.approved ? 'approved' : 'denied'
        log('success', `✓ Approval ${approved}: ${event.data.response_text || 'No comment'}`)
      } else if (event.type === 'session_status_changed') {
        const status = event.data.new_status || event.data.status || 'unknown'
        log('info', `Session status: ${status}`)

        // Exit when session completes or fails
        if (status === 'completed' || status === 'failed') {
          markSessionDone()
        }
      }
    })

    log('info', '\n📋 Session Information:')
    log('info', `Session ID: ${session.session_id}`)
    log('info', `Run ID: ${session.run_id}`)
    log('info', '\n🛠️  Manage approvals:')
    log('info', '  TUI: npx humanlayer tui')
    log('info', '  WUI: Open the HumanLayer desktop app')
    log('info', '\n📊 Monitoring MCP logs...')
    log('info', 'Press Ctrl+C to stop monitoring\n')

    // Keep running until session completes or interrupted
    await new Promise<void>(resolve => {
      let finished = false

      const sigintHandler = () => {
        if (finished) return
        finished = true
        log('info', '\nStopping monitor...')
        process.removeListener('SIGINT', sigintHandler)
        resolve()
      }
      process.on('SIGINT', sigintHandler)

      // Event-driven replacement for polling a completion flag
      sessionDone.then(() => {
        if (finished) return
        finished = true
        process.removeListener('SIGINT', sigintHandler)
        log('info', '\n✨ Session completed, exiting...')
        resolve()
      })
    })
  } catch (error) {
    // process.exit below would otherwise discard the exception silently
    exitCode = 1
    log('error', `Error: ${error}`)
  } finally {
    // Stop the tail process on every path so it is not left running
    if (stopMonitoring) {
      stopMonitoring()
    }
    client.close()
    process.exit(exitCode)
  }
}
|
||||
|
||||
// Main entry point
|
||||
/**
 * CLI entry point: parses flags, checks that hlyr has been built, makes sure
 * the MCP log directory exists, then runs the requested mode.
 *
 * `--test` selects the automated test. Every other invocation (with or
 * without `--interactive`) runs interactive mode, optionally with `--query`.
 * Exits with status 1 when the hlyr build is missing or the selected mode throws.
 */
async function main() {
  const { values: flags } = parseArgs({
    args: process.argv.slice(2),
    options: {
      test: { type: 'boolean', short: 't', default: false },
      interactive: { type: 'boolean', short: 'i', default: false },
      query: { type: 'string', short: 'q' },
    },
  })

  // This script lives in hlyr/hack, so the compiled CLI is one level up in dist/.
  const builtEntry = join(__dirname, '..', 'dist', 'index.js')
  try {
    await fs.access(builtEntry)
  } catch {
    log('error', 'hlyr is not built. Please run: cd .. && npm install && npm run build')
    process.exit(1)
  }

  // The MCP server writes its logs here; create the directory up front.
  await fs.mkdir(MCP_LOG_DIR, { recursive: true })

  try {
    // Interactive mode is the default, so --interactive never changes the outcome.
    await (flags.test ? runAutomatedTest() : runInteractiveMode(flags.query))
  } catch (failure) {
    log('error', `Error: ${failure}`)
    process.exit(1)
  }
}
|
||||
|
||||
// Show usage if needed
|
||||
if (process.argv.includes('--help') || process.argv.includes('-h')) {
|
||||
console.log(`
|
||||
Local MCP Approvals Test Tool
|
||||
|
||||
Usage:
|
||||
bun test-local-approvals.ts [options]
|
||||
|
||||
Options:
|
||||
-t, --test Run automated test (launches session, triggers approval, auto-approves)
|
||||
-i, --interactive Run in interactive mode (launches session, monitors logs, manual approval)
|
||||
-q, --query Custom query for the session (interactive mode only)
|
||||
-h, --help Show this help message
|
||||
|
||||
Examples:
|
||||
# Run automated test
|
||||
bun test-local-approvals.ts --test
|
||||
|
||||
# Interactive mode (default - will request to write to blah.txt)
|
||||
bun test-local-approvals.ts
|
||||
|
||||
# Interactive mode with custom query
|
||||
bun test-local-approvals.ts -q "Help me analyze this codebase"
|
||||
|
||||
# Interactive mode without triggering approval
|
||||
bun test-local-approvals.ts -q "Hello, how are you?"
|
||||
|
||||
Notes:
|
||||
- Run from the hlyr/hack directory
|
||||
- Make sure the daemon is running: cd ../../hld && ./dist/bin/hld -debug
|
||||
- Build hlyr first: cd .. && npm install && npm run build
|
||||
- In interactive mode, use TUI or WUI to approve/deny
|
||||
- MCP logs are saved to: ~/.humanlayer/logs/
|
||||
`)
|
||||
process.exit(0)
|
||||
}
|
||||
|
||||
// Run main
|
||||
// Anything that escapes main() is fatal: report it and exit non-zero.
main().catch(fatal => {
  console.error('Fatal error:', fatal)
  process.exit(1)
})
|
||||
@@ -40,12 +40,30 @@ interface Event {
|
||||
type: 'new_approval' | 'approval_resolved' | 'session_status_changed'
|
||||
timestamp: string
|
||||
data: {
|
||||
type?: 'function_call' | 'human_contact'
|
||||
count?: number
|
||||
// Common fields
|
||||
session_id?: string
|
||||
run_id?: string
|
||||
|
||||
// new_approval event fields
|
||||
approval_id?: string
|
||||
tool_name?: string
|
||||
|
||||
// approval_resolved event fields
|
||||
approved?: boolean
|
||||
response_text?: string
|
||||
|
||||
// session_status_changed event fields
|
||||
old_status?: string
|
||||
new_status?: string
|
||||
parent_session_id?: string
|
||||
|
||||
// Legacy fields (may not be used)
|
||||
type?: 'function_call' | 'human_contact'
|
||||
count?: number
|
||||
function_name?: string
|
||||
message?: string
|
||||
|
||||
// Allow other fields
|
||||
[key: string]: string | number | boolean | undefined
|
||||
}
|
||||
}
|
||||
@@ -54,6 +72,18 @@ interface EventNotification {
|
||||
event: Event
|
||||
}
|
||||
|
||||
/** States an approval can be in: `pending` until a decision is recorded. */
type ApprovalStatus = 'pending' | 'approved' | 'denied'

/**
 * A local approval record as returned by the daemon's `fetchApprovals` and
 * `getApproval` RPCs.
 */
export interface Approval {
  /** Approval identifier; this is the value passed to `getApproval`. */
  id: string
  run_id: string
  session_id: string
  status: ApprovalStatus
  // NOTE(review): timestamps arrive as strings; the exact format is not visible here.
  created_at: string
  /** Presumably set once a decision has been recorded; confirm against the daemon. */
  responded_at?: string
  /** Name of the tool the approval was requested for. */
  tool_name: string
  /** Tool arguments, passed through untyped. */
  tool_input: unknown
  /** Optional reviewer comment attached to the decision. */
  comment?: string
}
|
||||
|
||||
interface LaunchSessionRequest {
|
||||
query: string
|
||||
model?: string
|
||||
@@ -316,15 +346,32 @@ export class DaemonClient extends EventEmitter {
|
||||
return this.call<{ sessions: unknown[] }>('listSessions')
|
||||
}
|
||||
|
||||
async fetchApprovals(sessionId: string): Promise<unknown[]> {
|
||||
const resp = await this.call<{ approvals: unknown[] }>('fetchApprovals', { session_id: sessionId })
|
||||
/**
 * Ask the daemon to create an approval for a tool call.
 *
 * @param runId     run the tool call belongs to (sent as `run_id`)
 * @param toolName  tool requesting permission (sent as `tool_name`)
 * @param toolInput tool arguments, forwarded untouched (sent as `tool_input`)
 * @returns the daemon's response carrying the new `approval_id`
 */
async createApproval(runId: string, toolName: string, toolInput: unknown): Promise<{ approval_id: string }> {
  const params = { run_id: runId, tool_name: toolName, tool_input: toolInput }
  return this.call<{ approval_id: string }>('createApproval', params)
}
|
||||
|
||||
/**
 * Fetch the approvals the daemon returns for one session.
 *
 * @param sessionId session to query (sent as `session_id`)
 * @returns the `approvals` array from the daemon's response
 */
async fetchApprovals(sessionId: string): Promise<Approval[]> {
  const { approvals } = await this.call<{ approvals: Approval[] }>('fetchApprovals', {
    session_id: sessionId,
  })
  return approvals
}
|
||||
|
||||
async sendDecision(callId: string, type: string, decision: string, comment: string): Promise<void> {
|
||||
/**
 * Look up a single approval by ID.
 *
 * @param approvalId approval to fetch (sent as `approval_id`)
 * @returns the `approval` object from the daemon's response
 */
async getApproval(approvalId: string): Promise<Approval> {
  const { approval } = await this.call<{ approval: Approval }>('getApproval', {
    approval_id: approvalId,
  })
  return approval
}
|
||||
|
||||
async sendDecision(approvalId: string, decision: string, comment: string): Promise<void> {
|
||||
const resp = await this.call<{ success: boolean; error?: string }>('sendDecision', {
|
||||
call_id: callId,
|
||||
type,
|
||||
call_id: approvalId, // Using call_id for backward compatibility
|
||||
type: 'function_call', // Always function_call for local approvals
|
||||
decision,
|
||||
comment,
|
||||
})
|
||||
|
||||
150
hlyr/src/mcp.ts
150
hlyr/src/mcp.ts
@@ -8,6 +8,8 @@ import {
|
||||
} from '@modelcontextprotocol/sdk/types.js'
|
||||
import { humanlayer } from '@humanlayer/sdk'
|
||||
import { resolveFullConfig } from './config.js'
|
||||
import { DaemonClient } from './daemonClient.js'
|
||||
import { logger } from './mcpLogger.js'
|
||||
|
||||
function validateAuth(): void {
|
||||
const config = resolveFullConfig({})
|
||||
@@ -95,13 +97,16 @@ export async function startDefaultMCPServer() {
|
||||
/**
|
||||
* Start the Claude approvals MCP server that provides request_permission functionality
|
||||
* Returns responses in the format required by Claude Code SDK
|
||||
*
|
||||
* This now uses local approvals through the daemon instead of HumanLayer API
|
||||
*/
|
||||
export async function startClaudeApprovalsMCPServer() {
|
||||
validateAuth()
|
||||
// No auth validation needed - uses local daemon
|
||||
logger.info('Starting Claude approvals MCP server')
|
||||
|
||||
const server = new Server(
|
||||
{
|
||||
name: 'humanlayer-claude-approvals',
|
||||
name: 'humanlayer-claude-local-approvals',
|
||||
version: '1.0.0',
|
||||
},
|
||||
{
|
||||
@@ -111,16 +116,8 @@ export async function startClaudeApprovalsMCPServer() {
|
||||
},
|
||||
)
|
||||
|
||||
const resolvedConfig = resolveFullConfig({})
|
||||
|
||||
const hl = humanlayer({
|
||||
apiKey: resolvedConfig.api_key,
|
||||
...(resolvedConfig.api_base_url && { apiBaseUrl: resolvedConfig.api_base_url }),
|
||||
...(resolvedConfig.run_id && { runId: resolvedConfig.run_id }),
|
||||
...(Object.keys(resolvedConfig.contact_channel).length > 0 && {
|
||||
contactChannel: resolvedConfig.contact_channel,
|
||||
}),
|
||||
})
|
||||
// Create daemon client
|
||||
const daemonClient = new DaemonClient()
|
||||
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||
return {
|
||||
@@ -142,56 +139,114 @@ export async function startClaudeApprovalsMCPServer() {
|
||||
})
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||
/**
|
||||
* example input
|
||||
* {
|
||||
* "tool_name": "Write",
|
||||
* "input": {
|
||||
* "file_name": "hello.txt"
|
||||
* "content": "Hello, how are you?"
|
||||
* }
|
||||
* }
|
||||
*/
|
||||
logger.debug('Received tool call request', { name: request.params.name })
|
||||
|
||||
if (request.params.name === 'request_permission') {
|
||||
const toolName: string | undefined = request.params.arguments?.tool_name
|
||||
|
||||
if (!toolName) {
|
||||
logger.error('Invalid tool name in request_permission', request.params.arguments)
|
||||
throw new McpError(ErrorCode.InvalidRequest, 'Invalid tool name requesting permissions')
|
||||
}
|
||||
|
||||
const input: Record<string, unknown> = request.params.arguments?.input || {}
|
||||
|
||||
const approvalResult = await hl.fetchHumanApproval({
|
||||
spec: {
|
||||
fn: toolName,
|
||||
kwargs: input,
|
||||
},
|
||||
})
|
||||
// Get run ID from environment (set by Claude Code)
|
||||
const runId = process.env.HUMANLAYER_RUN_ID
|
||||
if (!runId) {
|
||||
logger.error('HUMANLAYER_RUN_ID not set in environment')
|
||||
throw new McpError(ErrorCode.InternalError, 'HUMANLAYER_RUN_ID not set')
|
||||
}
|
||||
|
||||
if (!approvalResult.approved) {
|
||||
logger.info('Processing approval request', { runId, toolName })
|
||||
|
||||
try {
|
||||
// Connect to daemon
|
||||
logger.debug('Connecting to daemon...')
|
||||
await daemonClient.connect()
|
||||
logger.debug('Connected to daemon')
|
||||
|
||||
// Create approval request
|
||||
logger.debug('Creating approval request...', { runId, toolName })
|
||||
const createResponse = await daemonClient.createApproval(runId, toolName, input)
|
||||
const approvalId = createResponse.approval_id
|
||||
logger.info('Created approval', { approvalId })
|
||||
|
||||
// Poll for approval status
|
||||
let approved = false
|
||||
let comment = ''
|
||||
let polling = true
|
||||
|
||||
while (polling) {
|
||||
try {
|
||||
// Get the specific approval by ID
|
||||
logger.debug('Fetching approval status...', { approvalId })
|
||||
const approval = (await daemonClient.getApproval(approvalId)) as {
|
||||
id: string
|
||||
status: string
|
||||
comment?: string
|
||||
}
|
||||
|
||||
logger.debug('Approval status', { status: approval.status })
|
||||
|
||||
if (approval.status !== 'pending') {
|
||||
// Approval has been resolved
|
||||
approved = approval.status === 'approved'
|
||||
comment = approval.comment || ''
|
||||
polling = false
|
||||
logger.info('Approval resolved', {
|
||||
approvalId,
|
||||
status: approval.status,
|
||||
approved,
|
||||
})
|
||||
} else {
|
||||
// Still pending, wait and poll again
|
||||
logger.debug('Approval still pending, polling again...')
|
||||
await new Promise(resolve => setTimeout(resolve, 1000))
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to get approval status', { error, approvalId })
|
||||
// Re-throw the error since this is a critical failure
|
||||
throw new McpError(ErrorCode.InternalError, 'Failed to get approval status')
|
||||
}
|
||||
}
|
||||
|
||||
if (!approved) {
|
||||
logger.info('Approval denied', { approvalId, comment })
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
behavior: 'deny',
|
||||
message: comment || 'Request denied by human reviewer',
|
||||
}),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('Approval granted', { approvalId })
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
behavior: 'deny',
|
||||
message: approvalResult.comment || 'Request denied by human reviewer',
|
||||
behavior: 'allow',
|
||||
updatedInput: input,
|
||||
}),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
behavior: 'allow',
|
||||
updatedInput: input,
|
||||
}),
|
||||
},
|
||||
],
|
||||
} catch (error) {
|
||||
logger.error('Failed to process approval', error)
|
||||
throw new McpError(
|
||||
ErrorCode.InternalError,
|
||||
`Failed to process approval: ${error instanceof Error ? error.message : String(error)}`,
|
||||
)
|
||||
} finally {
|
||||
logger.debug('Closing daemon connection')
|
||||
daemonClient.close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,5 +254,12 @@ export async function startClaudeApprovalsMCPServer() {
|
||||
})
|
||||
|
||||
const transport = new StdioServerTransport()
|
||||
await server.connect(transport)
|
||||
|
||||
try {
|
||||
await server.connect(transport)
|
||||
logger.info('MCP server connected and ready')
|
||||
} catch (error) {
|
||||
logger.error('Failed to start MCP server', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
141
hlyr/src/mcpLogger.ts
Normal file
141
hlyr/src/mcpLogger.ts
Normal file
@@ -0,0 +1,141 @@
|
||||
import * as fs from 'fs'
|
||||
import * as path from 'path'
|
||||
import * as os from 'os'
|
||||
|
||||
/** Severity labels written to the log file. */
type LogLevel = 'DEBUG' | 'INFO' | 'WARN' | 'ERROR'

/** Shape of one line in the MCP log file (serialised as a single JSON object). */
interface LogEntry {
  /** ISO-8601 timestamp string. */
  timestamp: string
  level: LogLevel
  /** Name of the component that emitted the entry. */
  component: string
  message: string
  /** Optional structured payload attached to the message. */
  data?: unknown
}
|
||||
|
||||
/**
 * Append-only JSON-lines file logger for the Claude approvals MCP server.
 *
 * Log files live in `~/.humanlayer/logs/` and are named
 * `mcp-claude-approvals-<id>.log`, where `<id>` is `HUMANLAYER_RUN_ID` when
 * that variable is set and the current UTC date (YYYY-MM-DD) otherwise.
 *
 * WARN and ERROR entries are always written. DEBUG and INFO entries are
 * written only when `HUMANLAYER_MCP_DEBUG` is exactly `'true'`.
 *
 * Logging is best-effort after construction: a failing log file or an
 * unserialisable payload never throws into the caller.
 */
class MCPLogger {
  /** Component name stamped on every entry. */
  private static readonly COMPONENT = 'mcp-claude-approvals'

  private logDir: string
  private logPath: string
  private isDebug: boolean
  private writeStream?: fs.WriteStream

  constructor() {
    // Check if debug mode is enabled
    this.isDebug = process.env.HUMANLAYER_MCP_DEBUG === 'true'

    // Create log directory
    this.logDir = path.join(os.homedir(), '.humanlayer', 'logs')
    if (!fs.existsSync(this.logDir)) {
      fs.mkdirSync(this.logDir, { recursive: true })
    }

    // Create log file with run ID if available, otherwise use date
    const runId = process.env.HUMANLAYER_RUN_ID
    const identifier = runId || new Date().toISOString().split('T')[0]
    this.logPath = path.join(this.logDir, `mcp-claude-approvals-${identifier}.log`)

    // Open write stream in append mode
    this.writeStream = fs.createWriteStream(this.logPath, { flags: 'a' })

    // A stream 'error' event with no listener is rethrown as an uncaught
    // exception, which would take the whole MCP server down because of a log
    // file problem. Report once on stderr (stdout carries the MCP protocol)
    // and stop writing instead.
    this.writeStream.on('error', err => {
      this.writeStream = undefined
      console.error(`MCP logger disabled, cannot write ${this.logPath}: ${err.message}`)
    })

    // Log startup
    this.info('MCP Logger initialized', {
      logPath: this.logPath,
      debug: this.isDebug,
      runId: process.env.HUMANLAYER_RUN_ID,
    })
  }

  /**
   * JSON.stringify replacer that expands Error instances.
   *
   * `name`, `message` and `stack` are not enumerable, so a plain stringify
   * turns an Error (including one nested in a data object such as
   * `{ error, approvalId }`) into `{}`. Enumerable extras such as `code` are kept.
   */
  private static expandErrors(_key: string, value: unknown): unknown {
    if (value instanceof Error) {
      return { ...value, name: value.name, message: value.message, stack: value.stack }
    }
    return value
  }

  /** Serialise one entry as a JSON line and append it to the log file. */
  private write(entry: LogEntry): void {
    if (!this.writeStream) return

    let line: string
    try {
      line = JSON.stringify(entry, MCPLogger.expandErrors)
    } catch (err) {
      // Circular references or BigInt values make stringify throw; keep the
      // message and replace the payload with a note rather than crash the caller.
      const reason = err instanceof Error ? err.message : String(err)
      line = JSON.stringify({ ...entry, data: `[unserializable log data: ${reason}]` })
    }
    this.writeStream.write(line + '\n')
  }

  /** Returns whether entries at `level` are written under the current settings. */
  private shouldLog(level: LogEntry['level']): boolean {
    // Always log warnings and errors
    if (level === 'WARN' || level === 'ERROR') return true

    // Only log debug/info if debug mode is enabled
    return this.isDebug
  }

  /** Shared implementation behind the public level methods. */
  private log(level: LogEntry['level'], message: string, data?: unknown): void {
    if (!this.shouldLog(level)) return

    this.write({
      timestamp: new Date().toISOString(),
      level,
      component: MCPLogger.COMPONENT,
      message,
      data,
    })
  }

  /** Write a DEBUG entry (debug mode only). */
  debug(message: string, data?: unknown): void {
    this.log('DEBUG', message, data)
  }

  /** Write an INFO entry (debug mode only). */
  info(message: string, data?: unknown): void {
    this.log('INFO', message, data)
  }

  /** Write a WARN entry (always written). */
  warn(message: string, data?: unknown): void {
    this.log('WARN', message, data)
  }

  /**
   * Write an ERROR entry (always written).
   *
   * @param error an Error, or any other value to attach as data; Error
   *              instances are expanded to name/message/stack on serialisation
   */
  error(message: string, error?: unknown): void {
    this.log('ERROR', message, error)
  }

  /** End the write stream; later log calls become no-ops. Safe to call twice. */
  close(): void {
    if (this.writeStream) {
      this.writeStream.end()
      this.writeStream = undefined
    }
  }
}
|
||||
|
||||
// Single logger instance shared by every module that imports this file.
export const logger = new MCPLogger()

// Close the log file whenever the process ends normally.
process.on('exit', () => {
  logger.close()
})

// On an interrupt or termination signal, close the log file and exit with status 0.
for (const signal of ['SIGINT', 'SIGTERM'] as const) {
  process.on(signal, () => {
    logger.close()
    process.exit(0)
  })
}
|
||||
125
hlyr/test_local_approvals.md
Normal file
125
hlyr/test_local_approvals.md
Normal file
@@ -0,0 +1,125 @@
|
||||
# Testing Local MCP Approvals
|
||||
|
||||
This guide explains how to test the local MCP approvals system without requiring HumanLayer API access.
|
||||
|
||||
## Overview
|
||||
|
||||
The `hack/test-local-approvals.ts` script provides a comprehensive testing tool for verifying that the MCP server, daemon, and approval flow are working correctly with local-only approvals.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Build hlyr and the daemon:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
2. Start the daemon with debug logging (run from the `hld` directory):
|
||||
|
||||
```bash
|
||||
./dist/bin/hld -debug
|
||||
```
|
||||
|
||||
3. Have Bun installed (for running TypeScript directly)
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Automated Test Mode
|
||||
|
||||
Launches a Claude session, triggers a file write approval, and automatically approves it after 2 seconds:
|
||||
|
||||
```bash
|
||||
bun hack/test-local-approvals.ts --test
|
||||
```
|
||||
|
||||
This mode is useful for:
|
||||
|
||||
- CI/CD pipelines
|
||||
- Quick verification that the system is working
|
||||
- Debugging the approval flow
|
||||
|
||||
### Interactive Mode (Default)
|
||||
|
||||
Launches a Claude session with a query that will trigger an approval, then monitors for events:
|
||||
|
||||
```bash
|
||||
# Default query (writes to blah.txt with random content)
|
||||
bun hack/test-local-approvals.ts
|
||||
|
||||
# Custom query
|
||||
bun hack/test-local-approvals.ts -q "Help me analyze this codebase"
|
||||
|
||||
# Query that won't trigger approvals
|
||||
bun hack/test-local-approvals.ts -q "Hello, how are you?"
|
||||
```
|
||||
|
||||
While running in interactive mode:
|
||||
|
||||
- Approval requests will be highlighted in the console
|
||||
- Use TUI (`npx humanlayer tui`) or WUI to approve/deny
|
||||
- Press Ctrl+C to stop monitoring
|
||||
|
||||
## What the Test Does
|
||||
|
||||
1. **Connects to the daemon** via Unix socket
|
||||
2. **Launches a Claude session** with MCP approvals enabled
|
||||
3. **Monitors MCP logs** in real-time at `~/.humanlayer/logs/`
|
||||
4. **Subscribes to daemon events**:
|
||||
- `new_approval` - When an approval is requested
|
||||
- `approval_resolved` - When an approval is approved/denied
|
||||
- `session_status_changed` - When session status changes
|
||||
5. **In test mode**: Automatically approves after 2 seconds
|
||||
6. **In interactive mode**: Waits for manual approval via TUI/WUI
|
||||
|
||||
## Understanding the Output
|
||||
|
||||
### Successful Automated Test
|
||||
|
||||
```
|
||||
[INFO] === Automated MCP Approval Test ===
|
||||
[SUCCESS] Connected to daemon
|
||||
[SUCCESS] Session launched: <session-id>
|
||||
[SUCCESS] New approval event received!
|
||||
[SUCCESS] ✓ Approval sent successfully
|
||||
[SUCCESS] ✓ File "test-mcp-approval-XXX.txt" was created successfully
|
||||
[SUCCESS] ✓ No errors in MCP logs
|
||||
```
|
||||
|
||||
### Interactive Mode Events
|
||||
|
||||
```
|
||||
🔔 NEW APPROVAL REQUEST!
|
||||
Approval ID: local-XXXX
|
||||
Tool: Write
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Failed to connect to daemon"
|
||||
|
||||
- Ensure the daemon is running: `./dist/bin/hld -debug`
|
||||
- Check the socket exists: `ls ~/.humanlayer/daemon.sock`
|
||||
|
||||
### "hlyr is not built"
|
||||
|
||||
- Run `npm run build` from the hlyr directory
|
||||
|
||||
### No approval triggered
|
||||
|
||||
- The default query asks Claude to write random content to `blah.txt`, so it should request an approval on every run
|
||||
- If using a custom query, make sure it requests an action (like writing a file)
|
||||
|
||||
### MCP errors in logs
|
||||
|
||||
- Check `~/.humanlayer/logs/mcp-claude-approvals-*.log` for details
|
||||
- Ensure you're using the latest built version
|
||||
|
||||
## Command Reference
|
||||
|
||||
```text
|
||||
Options:
|
||||
-t, --test Run automated test
|
||||
-i, --interactive Run in interactive mode (default)
|
||||
-q, --query Custom query for the session
|
||||
-h, --help Show help message
|
||||
```
|
||||
@@ -1,9 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
@@ -131,14 +131,18 @@ func fetchRequests(daemonClient client.Client) tea.Cmd {
|
||||
|
||||
// Convert approvals to our Request type
|
||||
for _, approval := range approvals {
|
||||
if approval.Type == "function_call" && approval.FunctionCall != nil {
|
||||
fc := approval.FunctionCall
|
||||
// Build a message from the function name and kwargs
|
||||
message := fmt.Sprintf("Call %s", fc.Spec.Fn)
|
||||
if len(fc.Spec.Kwargs) > 0 {
|
||||
// For local approvals, we always have tool calls
|
||||
// Build a message from the tool name and input
|
||||
message := fmt.Sprintf("Call %s", approval.ToolName)
|
||||
|
||||
// Parse tool input to extract parameters
|
||||
var toolParams map[string]interface{}
|
||||
if len(approval.ToolInput) > 0 {
|
||||
// Try to parse as JSON to get parameters
|
||||
if err := json.Unmarshal(approval.ToolInput, &toolParams); err == nil && len(toolParams) > 0 {
|
||||
// Add first few parameters to message
|
||||
params := []string{}
|
||||
for k, v := range fc.Spec.Kwargs {
|
||||
for k, v := range toolParams {
|
||||
params = append(params, fmt.Sprintf("%s=%v", k, v))
|
||||
if len(params) >= 2 {
|
||||
break
|
||||
@@ -146,62 +150,30 @@ func fetchRequests(daemonClient client.Client) tea.Cmd {
|
||||
}
|
||||
message += fmt.Sprintf(" with %s", strings.Join(params, ", "))
|
||||
}
|
||||
|
||||
createdAt := time.Now() // Default to now if not available
|
||||
if fc.Status != nil && fc.Status.RequestedAt != nil {
|
||||
createdAt = fc.Status.RequestedAt.Time
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: fc.CallID,
|
||||
CallID: fc.CallID,
|
||||
RunID: fc.RunID,
|
||||
Type: ApprovalRequest,
|
||||
Message: message,
|
||||
Tool: fc.Spec.Fn,
|
||||
Parameters: fc.Spec.Kwargs,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
// Enrich with session info if available
|
||||
if sess, ok := sessionsByRunID[fc.RunID]; ok {
|
||||
req.SessionID = sess.ID
|
||||
req.SessionQuery = truncate(sess.Query, 50)
|
||||
req.SessionModel = sess.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
allRequests = append(allRequests, req)
|
||||
} else if approval.Type == "human_contact" && approval.HumanContact != nil {
|
||||
hc := approval.HumanContact
|
||||
createdAt := time.Now() // Default to now if not available
|
||||
if hc.Status != nil && hc.Status.RequestedAt != nil {
|
||||
createdAt = hc.Status.RequestedAt.Time
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: hc.CallID,
|
||||
CallID: hc.CallID,
|
||||
RunID: hc.RunID,
|
||||
Type: HumanContactRequest,
|
||||
Message: hc.Spec.Msg,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
// Enrich with session info if available
|
||||
if sess, ok := sessionsByRunID[hc.RunID]; ok {
|
||||
req.SessionID = sess.ID
|
||||
req.SessionQuery = truncate(sess.Query, 50)
|
||||
req.SessionModel = sess.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
allRequests = append(allRequests, req)
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: approval.ID,
|
||||
CallID: approval.ID, // Use approval ID as call ID
|
||||
RunID: approval.RunID,
|
||||
Type: ApprovalRequest,
|
||||
Message: message,
|
||||
Tool: approval.ToolName,
|
||||
Parameters: toolParams,
|
||||
CreatedAt: approval.CreatedAt,
|
||||
}
|
||||
|
||||
// Enrich with session info if available
|
||||
if sess, ok := sessionsByRunID[approval.RunID]; ok {
|
||||
req.SessionID = sess.ID
|
||||
req.SessionQuery = truncate(sess.Query, 50)
|
||||
req.SessionModel = sess.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
allRequests = append(allRequests, req)
|
||||
}
|
||||
|
||||
return fetchRequestsMsg{requests: allRequests}
|
||||
@@ -247,14 +219,17 @@ func fetchSessionApprovals(daemonClient client.Client, sessionID string) tea.Cmd
|
||||
// Convert to Request type
|
||||
var requests []Request
|
||||
for _, approval := range approvals {
|
||||
if approval.Type == "function_call" && approval.FunctionCall != nil {
|
||||
fc := approval.FunctionCall
|
||||
message := fmt.Sprintf("Call %s", fc.Spec.Fn)
|
||||
// For local approvals, we always have tool calls
|
||||
message := fmt.Sprintf("Call %s", approval.ToolName)
|
||||
|
||||
// Add parameters to message
|
||||
if len(fc.Spec.Kwargs) > 0 {
|
||||
// Parse tool input to extract parameters
|
||||
var toolParams map[string]interface{}
|
||||
if len(approval.ToolInput) > 0 {
|
||||
// Try to parse as JSON to get parameters
|
||||
if err := json.Unmarshal(approval.ToolInput, &toolParams); err == nil && len(toolParams) > 0 {
|
||||
// Add first few parameters to message
|
||||
params := []string{}
|
||||
for k, v := range fc.Spec.Kwargs {
|
||||
for k, v := range toolParams {
|
||||
params = append(params, fmt.Sprintf("%s=%v", k, v))
|
||||
if len(params) >= 2 {
|
||||
break
|
||||
@@ -262,62 +237,30 @@ func fetchSessionApprovals(daemonClient client.Client, sessionID string) tea.Cmd
|
||||
}
|
||||
message += fmt.Sprintf(" with %s", strings.Join(params, ", "))
|
||||
}
|
||||
|
||||
createdAt := time.Now()
|
||||
if fc.Status != nil && fc.Status.RequestedAt != nil {
|
||||
createdAt = fc.Status.RequestedAt.Time
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: fc.CallID,
|
||||
CallID: fc.CallID,
|
||||
RunID: fc.RunID,
|
||||
Type: ApprovalRequest,
|
||||
Message: message,
|
||||
Tool: fc.Spec.Fn,
|
||||
Parameters: fc.Spec.Kwargs,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
// Add session info if available
|
||||
if sessionInfo != nil {
|
||||
req.SessionID = sessionInfo.ID
|
||||
req.SessionQuery = truncate(sessionInfo.Query, 50)
|
||||
req.SessionModel = sessionInfo.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
requests = append(requests, req)
|
||||
} else if approval.Type == "human_contact" && approval.HumanContact != nil {
|
||||
hc := approval.HumanContact
|
||||
createdAt := time.Now()
|
||||
if hc.Status != nil && hc.Status.RequestedAt != nil {
|
||||
createdAt = hc.Status.RequestedAt.Time
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: hc.CallID,
|
||||
CallID: hc.CallID,
|
||||
RunID: hc.RunID,
|
||||
Type: HumanContactRequest,
|
||||
Message: hc.Spec.Msg,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
// Add session info if available
|
||||
if sessionInfo != nil {
|
||||
req.SessionID = sessionInfo.ID
|
||||
req.SessionQuery = truncate(sessionInfo.Query, 50)
|
||||
req.SessionModel = sessionInfo.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
requests = append(requests, req)
|
||||
}
|
||||
|
||||
req := Request{
|
||||
ID: approval.ID,
|
||||
CallID: approval.ID, // Use approval ID as call ID
|
||||
RunID: approval.RunID,
|
||||
Type: ApprovalRequest,
|
||||
Message: message,
|
||||
Tool: approval.ToolName,
|
||||
Parameters: toolParams,
|
||||
CreatedAt: approval.CreatedAt,
|
||||
}
|
||||
|
||||
// Add session info if available
|
||||
if sessionInfo != nil {
|
||||
req.SessionID = sessionInfo.ID
|
||||
req.SessionQuery = truncate(sessionInfo.Query, 50)
|
||||
req.SessionModel = sessionInfo.Model
|
||||
if req.SessionModel == "" {
|
||||
req.SessionModel = "default"
|
||||
}
|
||||
}
|
||||
|
||||
requests = append(requests, req)
|
||||
}
|
||||
|
||||
return fetchSessionApprovalsMsg{approvals: requests}
|
||||
@@ -386,9 +329,9 @@ func sendApproval(daemonClient client.Client, callID string, approved bool, comm
|
||||
return func() tea.Msg {
|
||||
var err error
|
||||
if approved {
|
||||
err = daemonClient.ApproveFunctionCall(callID, comment)
|
||||
err = daemonClient.ApproveToolCall(callID, comment)
|
||||
} else {
|
||||
err = daemonClient.DenyFunctionCall(callID, comment)
|
||||
err = daemonClient.DenyToolCall(callID, comment)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -401,7 +344,8 @@ func sendApproval(daemonClient client.Client, callID string, approved bool, comm
|
||||
|
||||
func sendHumanResponse(daemonClient client.Client, requestID string, response string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := daemonClient.RespondToHumanContact(requestID, response)
|
||||
// Human contact is no longer supported in local approvals
|
||||
err := fmt.Errorf("human contact approvals are no longer supported")
|
||||
if err != nil {
|
||||
return humanResponseSentMsg{err: err}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user