Correlate Approvals with Tool Calls (#206)

* initial phase 3 implementation

* fix(approval): correlate approvals by looking up session ID from run_id

  - Fix GetPendingToolCall call to use session ID instead of run_id
  - Add session lookup by run_id before attempting correlation
  - Improve debug logging for correlation process

* fix(store): handle in-memory SQLite databases correctly

  - Skip directory creation for :memory: databases
  - Prevents error when using in-memory databases for testing

* test: add comprehensive tests for approval correlation

  - Unit tests for approval correlation logic
  - Unit tests for session status transitions
  - Integration test for orphaned session cleanup
  - Integration test for full approval flow with session state changes

* ignore go warnings in commands

* switch to getting session by run id (way better)
This commit is contained in:
Allison Durham
2025-06-06 21:52:49 -05:00
committed by GitHub
parent e21e3944c7
commit 28bbe4ba08
12 changed files with 1041 additions and 55 deletions

View File

@@ -9,15 +9,15 @@ test: test-unit test-integration
# Run unit tests with race detection
test-unit:
go test -v -race ./...
CGO_LDFLAGS="-Wl,-w" go test -v -race ./...
# Run integration tests (requires build tag)
test-integration: build
go test -v -tags=integration -run Integration ./daemon/...
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...
# Run tests with race detection
test-race:
go test -race ./...
CGO_LDFLAGS="-Wl,-w" go test -race ./...
# Clean build artifacts
clean:

View File

@@ -0,0 +1,157 @@
package approval
import (
"context"
"fmt"
"testing"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"go.uber.org/mock/gomock"
)
func TestPoller_CorrelateApproval(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tests := []struct {
name string
functionCall humanlayer.FunctionCall
setupStore func(mockStore *store.MockConversationStore)
expectCorrelate bool
expectStatusUpdate bool
}{
{
name: "successful_correlation",
functionCall: humanlayer.FunctionCall{
CallID: "fc-123",
RunID: "run-456",
Spec: humanlayer.FunctionCallSpec{
Fn: "dangerous_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// Expect GetSessionByRunID to find the session
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-456").
Return(&store.Session{
ID: "sess-789",
RunID: "run-456",
}, nil)
// Return a pending tool call
mockStore.EXPECT().
GetPendingToolCall(gomock.Any(), "sess-789", "dangerous_function").
Return(&store.ConversationEvent{
ID: 1,
SessionID: "sess-789",
ToolName: "dangerous_function",
}, nil)
// Expect correlation
mockStore.EXPECT().
CorrelateApproval(gomock.Any(), "sess-789", "dangerous_function", "fc-123").
Return(nil)
// Expect status update to waiting_input
mockStore.EXPECT().
UpdateSession(gomock.Any(), "sess-789", gomock.Any()).
DoAndReturn(func(ctx context.Context, sessionID string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusWaitingInput {
t.Errorf("expected status update to waiting_input, got %v", update.Status)
}
return nil
})
},
expectCorrelate: true,
expectStatusUpdate: true,
},
{
name: "no_matching_session",
functionCall: humanlayer.FunctionCall{
CallID: "fc-999",
RunID: "unknown-run",
Spec: humanlayer.FunctionCallSpec{
Fn: "some_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns nil (no matching session)
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "unknown-run").
Return(nil, nil)
// No further calls should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
{
name: "nil_tool_call_returned",
functionCall: humanlayer.FunctionCall{
CallID: "fc-nil",
RunID: "run-nil",
Spec: humanlayer.FunctionCallSpec{
Fn: "nil_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns a matching session
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-nil").
Return(&store.Session{
ID: "sess-nil",
RunID: "run-nil",
}, nil)
// Return nil tool call (no error but no result)
mockStore.EXPECT().
GetPendingToolCall(gomock.Any(), "sess-nil", "nil_function").
Return(nil, nil)
// No correlation or status update should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
{
name: "get_session_error",
functionCall: humanlayer.FunctionCall{
CallID: "fc-error",
RunID: "run-error",
Spec: humanlayer.FunctionCallSpec{
Fn: "error_function",
},
},
setupStore: func(mockStore *store.MockConversationStore) {
// GetSessionByRunID returns an error
mockStore.EXPECT().
GetSessionByRunID(gomock.Any(), "run-error").
Return(nil, fmt.Errorf("database error"))
// No further calls should happen
},
expectCorrelate: false,
expectStatusUpdate: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock store
mockStore := store.NewMockConversationStore(ctrl)
tt.setupStore(mockStore)
// Create poller with mock store
poller := &Poller{
conversationStore: mockStore,
}
// Call correlateApproval
ctx := context.Background()
poller.correlateApproval(ctx, tt.functionCall)
// The expectations are verified by gomock
})
}
}

View File

@@ -3,9 +3,11 @@ package approval
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
@@ -20,14 +22,15 @@ type Config struct {
// DefaultManager is the default implementation of Manager
type DefaultManager struct {
Client APIClient
Store Store
Poller *Poller
EventBus bus.EventBus
Client APIClient
Store Store
Poller *Poller
EventBus bus.EventBus
ConversationStore store.ConversationStore
}
// NewManager creates a new approval manager
func NewManager(cfg Config, eventBus bus.EventBus) (Manager, error) {
func NewManager(cfg Config, eventBus bus.EventBus, conversationStore store.ConversationStore) (Manager, error) {
// Set defaults
if cfg.PollInterval <= 0 {
cfg.PollInterval = 5 * time.Second
@@ -58,15 +61,16 @@ func NewManager(cfg Config, eventBus bus.EventBus) (Manager, error) {
store := NewMemoryStore()
// Create poller with configured interval
poller := NewPoller(client, store, cfg.PollInterval, eventBus)
poller := NewPoller(client, store, conversationStore, cfg.PollInterval, eventBus)
poller.maxBackoff = cfg.MaxBackoff
poller.backoffFactor = cfg.BackoffFactor
return &DefaultManager{
Client: client,
Store: store,
Poller: poller,
EventBus: eventBus,
Client: client,
Store: store,
Poller: poller,
EventBus: eventBus,
ConversationStore: conversationStore,
}, nil
}
@@ -82,10 +86,16 @@ func (m *DefaultManager) Stop() {
// GetPendingApprovals returns all pending approvals, optionally filtered by session
func (m *DefaultManager) GetPendingApprovals(sessionID string) ([]PendingApproval, error) {
if sessionID != "" {
// In the future, we'll need to look up the run_id for this session
// For now, return empty since we don't have session->run_id mapping yet
return []PendingApproval{}, nil
if sessionID != "" && m.ConversationStore != nil {
// Look up the session to get its run_id
ctx := context.Background()
session, err := m.ConversationStore.GetSession(ctx, sessionID)
if err != nil {
slog.Debug("session not found for approval filter", "session_id", sessionID, "error", err)
return []PendingApproval{}, nil
}
// Get approvals by run_id
return m.Store.GetPendingByRunID(session.RunID)
}
return m.Store.GetAllPending()
}
@@ -113,9 +123,40 @@ func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string,
return fmt.Errorf("failed to update local state: %w", err)
}
// Get the function call to access run_id
fc, _ := m.Store.GetFunctionCall(callID)
// Update approval status in database if we have a conversation store
if m.ConversationStore != nil && fc != nil {
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusApproved); err != nil {
slog.Error("failed to update approval status in database", "error", err)
// Don't fail the whole operation for this
}
// Update session status back to running since approval is resolved
if fc.RunID != "" {
// Look up the session by run_id
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
runningStatus := store.SessionStatusRunning
update := store.SessionUpdate{
Status: &runningStatus,
}
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to running",
"session_id", session.ID,
"error", err)
} else {
slog.Info("updated session status back to running after approval",
"session_id", session.ID,
"approval_id", callID)
}
}
}
}
// Publish event
if m.EventBus != nil {
fc, _ := m.Store.GetFunctionCall(callID)
if m.EventBus != nil && fc != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{
@@ -134,7 +175,7 @@ func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string,
// DenyFunctionCall denies a function call
func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
// First check if we have this function call
_, err := m.Store.GetFunctionCall(callID)
fc, err := m.Store.GetFunctionCall(callID)
if err != nil {
return fmt.Errorf("function call not found: %w", err)
}
@@ -149,9 +190,37 @@ func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, re
return fmt.Errorf("failed to update local state: %w", err)
}
// Update approval status in database if we have a conversation store
if m.ConversationStore != nil && fc != nil {
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusDenied); err != nil {
slog.Error("failed to update approval status in database", "error", err)
// Don't fail the whole operation for this
}
// Update session status back to running since approval is resolved (denied)
if fc.RunID != "" {
// Look up the session by run_id
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
runningStatus := store.SessionStatusRunning
update := store.SessionUpdate{
Status: &runningStatus,
}
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to running",
"session_id", session.ID,
"error", err)
} else {
slog.Info("updated session status back to running after denial",
"session_id", session.ID,
"approval_id", callID)
}
}
}
}
// Publish event
if m.EventBus != nil {
fc, _ := m.Store.GetFunctionCall(callID)
if m.EventBus != nil && fc != nil {
m.EventBus.Publish(bus.Event{
Type: bus.EventApprovalResolved,
Data: map[string]interface{}{

View File

@@ -0,0 +1,196 @@
package approval
import (
"context"
"testing"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
"go.uber.org/mock/gomock"
)
func TestManager_SessionStatusTransitions(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tests := []struct {
name string
setupTest func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string // returns callID
expectedStatus string
}{
{
name: "approve_updates_session_to_running",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-approve-123"
runID := "run-approve-456"
sessionID := "sess-approve-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).AnyTimes()
// API approval succeeds
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
// Get session by run_id
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusWaitingInput,
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// Expect status update to running
convStore.EXPECT().
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusRunning {
t.Errorf("expected status update to running, got %v", update.Status)
}
return nil
})
return callID
},
expectedStatus: store.SessionStatusRunning,
},
{
name: "deny_updates_session_to_running",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-deny-123"
runID := "run-deny-456"
sessionID := "sess-deny-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
// API denial succeeds
mockClient.EXPECT().DenyFunctionCall(gomock.Any(), callID, "test reason").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusDenied).Return(nil)
// Get session by run_id
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusWaitingInput,
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// Expect status update to running
convStore.EXPECT().
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
if update.Status == nil || *update.Status != store.SessionStatusRunning {
t.Errorf("expected status update to running, got %v", update.Status)
}
return nil
})
return callID
},
expectedStatus: store.SessionStatusRunning,
},
{
name: "no_status_update_for_non_waiting_session",
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
callID := "fc-nochange-123"
runID := "run-nochange-456"
sessionID := "sess-nochange-789"
// Function call exists
fc := &humanlayer.FunctionCall{
CallID: callID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "test_function",
},
}
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(2)
// API approval succeeds
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
// Mark as responded
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
// Update approval status
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
// Get session - session is already running, not waiting
session := &store.Session{
ID: sessionID,
RunID: runID,
Status: store.SessionStatusRunning, // Already running
}
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
// No status update should happen since session is not waiting
return callID
},
expectedStatus: store.SessionStatusRunning,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a new controller for each test to isolate expectations
testCtrl := gomock.NewController(t)
defer testCtrl.Finish()
// Create mocks
mockClient := NewMockAPIClient(testCtrl)
mockStore := NewMockStore(testCtrl)
convStore := store.NewMockConversationStore(testCtrl)
// Create manager
manager := &DefaultManager{
Client: mockClient,
Store: mockStore,
ConversationStore: convStore,
}
// Setup test expectations and get callID
callID := tt.setupTest(mockClient, mockStore, convStore)
// Execute based on test type
ctx := context.Background()
var err error
if tt.name == "deny_updates_session_to_running" {
err = manager.DenyFunctionCall(ctx, callID, "test reason")
} else {
err = manager.ApproveFunctionCall(ctx, callID, "test comment")
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Expectations are verified by gomock
})
}
}

View File

@@ -8,34 +8,38 @@ import (
"time"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// Poller polls the HumanLayer API for pending approvals
type Poller struct {
client APIClient
store Store
eventBus bus.EventBus
interval time.Duration
maxBackoff time.Duration
backoffFactor float64
mu sync.Mutex
cancel context.CancelFunc
failureCount int
client APIClient
store Store
conversationStore store.ConversationStore
eventBus bus.EventBus
interval time.Duration
maxBackoff time.Duration
backoffFactor float64
mu sync.Mutex
cancel context.CancelFunc
failureCount int
}
// NewPoller creates a new approval poller
func NewPoller(client APIClient, store Store, interval time.Duration, eventBus bus.EventBus) *Poller {
func NewPoller(client APIClient, store Store, conversationStore store.ConversationStore, interval time.Duration, eventBus bus.EventBus) *Poller {
if interval <= 0 {
interval = 5 * time.Second
}
return &Poller{
client: client,
store: store,
eventBus: eventBus,
interval: interval,
maxBackoff: 5 * time.Minute,
backoffFactor: 2.0,
failureCount: 0,
client: client,
store: store,
conversationStore: conversationStore,
eventBus: eventBus,
interval: interval,
maxBackoff: 5 * time.Minute,
backoffFactor: 2.0,
failureCount: 0,
}
}
@@ -138,6 +142,11 @@ func (p *Poller) poll(ctx context.Context) {
if err := p.store.StoreFunctionCall(fc); err != nil {
slog.Error("failed to store function call", "call_id", fc.CallID, "error", err)
}
// Correlate with tool calls in the database
if p.conversationStore != nil && fc.RunID != "" {
p.correlateApproval(pollCtx, fc)
}
}
slog.Debug("fetched function calls", "count", len(functionCalls), "new", newCount)
@@ -226,3 +235,62 @@ func (p *Poller) calculateIntervalLocked() time.Duration {
return time.Duration(backoff)
}
// correlateApproval attempts to match an approval with a pending tool call
func (p *Poller) correlateApproval(ctx context.Context, fc humanlayer.FunctionCall) {
// Find the session with this run_id
session, err := p.conversationStore.GetSessionByRunID(ctx, fc.RunID)
if err != nil {
slog.Error("failed to get session for correlation",
"run_id", fc.RunID,
"error", err)
return
}
if session == nil {
// This is expected for approvals that aren't from our Claude sessions
slog.Debug("no matching session for approval",
"run_id", fc.RunID,
"approval_id", fc.CallID)
return
}
toolName := fc.Spec.Fn
// Try to find a pending tool call for this session and tool
toolCall, err := p.conversationStore.GetPendingToolCall(ctx, session.ID, toolName)
if err != nil || toolCall == nil {
slog.Debug("no matching pending tool call for approval",
"session_id", session.ID,
"run_id", fc.RunID,
"tool_name", toolName,
"approval_id", fc.CallID)
return
}
// Found a matching tool call in one of our sessions - correlate it
if err := p.conversationStore.CorrelateApproval(ctx, session.ID, toolName, fc.CallID); err != nil {
slog.Error("failed to correlate approval with tool call",
"approval_id", fc.CallID,
"session_id", session.ID,
"tool_name", toolName,
"error", err)
return
}
slog.Info("correlated approval with tool call",
"approval_id", fc.CallID,
"session_id", session.ID,
"tool_name", toolName,
"run_id", fc.RunID)
// Update session status to waiting_input
waitingStatus := store.SessionStatusWaitingInput
update := store.SessionUpdate{
Status: &waitingStatus,
}
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to update session status to waiting_input",
"session_id", session.ID,
"error", err)
}
}

View File

@@ -134,7 +134,7 @@ func TestPoller_StartStop(t *testing.T) {
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
poller := NewPoller(mockClient, mockStore, 50*time.Millisecond, nil)
poller := NewPoller(mockClient, mockStore, nil, 50*time.Millisecond, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@@ -7,6 +7,7 @@ import (
"net"
"os"
"path/filepath"
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/bus"
@@ -94,7 +95,7 @@ func New() (*Daemon, error) {
BaseURL: cfg.APIBaseURL,
// Use defaults for now, could add to daemon config later
}
approvalManager, err = approval.NewManager(approvalCfg, eventBus)
approvalManager, err = approval.NewManager(approvalCfg, eventBus, conversationStore)
if err != nil {
return nil, fmt.Errorf("failed to create approval manager: %w", err)
}
@@ -147,6 +148,12 @@ func (d *Daemon) Run(ctx context.Context) error {
// Create and start RPC server
d.rpcServer = rpc.NewServer()
// Mark orphaned sessions as failed (from previous daemon run)
if err := d.markOrphanedSessionsAsFailed(ctx); err != nil {
slog.Warn("failed to mark orphaned sessions as failed", "error", err)
// Don't fail startup for this
}
// Register subscription handlers
subscriptionHandlers := rpc.NewSubscriptionHandlers(d.eventBus)
d.rpcServer.SetSubscriptionHandlers(subscriptionHandlers)
@@ -221,3 +228,49 @@ func (d *Daemon) handleConnection(ctx context.Context, conn net.Conn) {
slog.Debug("client disconnected", "remote", conn.RemoteAddr())
}
// markOrphanedSessionsAsFailed marks any sessions that were running or waiting
// when the daemon restarted as failed
func (d *Daemon) markOrphanedSessionsAsFailed(ctx context.Context) error {
if d.store == nil {
return nil
}
// Get all sessions from the database
sessions, err := d.store.ListSessions(ctx)
if err != nil {
return fmt.Errorf("failed to list sessions: %w", err)
}
orphanedCount := 0
for _, session := range sessions {
// Mark running or waiting sessions as failed
if session.Status == store.SessionStatusRunning ||
session.Status == store.SessionStatusWaitingInput ||
session.Status == store.SessionStatusStarting {
failedStatus := store.SessionStatusFailed
errorMsg := "daemon restarted while session was active"
now := time.Now()
update := store.SessionUpdate{
Status: &failedStatus,
CompletedAt: &now,
ErrorMessage: &errorMsg,
}
if err := d.store.UpdateSession(ctx, session.ID, update); err != nil {
slog.Error("failed to mark orphaned session as failed",
"session_id", session.ID,
"error", err)
// Continue with other sessions
} else {
orphanedCount++
}
}
}
if orphanedCount > 0 {
slog.Info("marked orphaned sessions as failed", "count", orphanedCount)
}
return nil
}

View File

@@ -119,15 +119,23 @@ func TestDaemonApprovalIntegration(t *testing.T) {
},
}
// Create a minimal in-memory store for testing
testStore, err := store.NewSQLiteStore(":memory:")
if err != nil {
t.Fatalf("failed to create test store: %v", err)
}
defer testStore.Close()
// Create real approval components for integration testing
approvalStore := approval.NewMemoryStore()
poller := approval.NewPoller(mockClient, approvalStore, 50*time.Millisecond, nil)
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, nil)
// We need to manually construct the manager with our test client
approvalManager := &approval.DefaultManager{
Client: mockClient,
Store: approvalStore,
Poller: poller,
Client: mockClient,
Store: approvalStore,
Poller: poller,
ConversationStore: testStore,
}
// Create test daemon with approval manager
@@ -141,12 +149,6 @@ func TestDaemonApprovalIntegration(t *testing.T) {
}
// Create session manager (we don't need real sessions for this test)
// Create a minimal in-memory store for testing
testStore, err := store.NewSQLiteStore(":memory:")
if err != nil {
t.Fatalf("failed to create test store: %v", err)
}
defer testStore.Close()
sessionManager, err := session.NewManager(nil, testStore)
if err != nil {

View File

@@ -0,0 +1,92 @@
package daemon
import (
"context"
"testing"
"github.com/humanlayer/humanlayer/hld/store"
"go.uber.org/mock/gomock"
)
func TestDaemon_MarkOrphanedSessions(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockConversationStore(ctrl)
// Set up sessions with various statuses
sessions := []*store.Session{
{
ID: "sess-running",
Status: store.SessionStatusRunning,
},
{
ID: "sess-waiting",
Status: store.SessionStatusWaitingInput,
},
{
ID: "sess-starting",
Status: store.SessionStatusStarting,
},
{
ID: "sess-completed",
Status: store.SessionStatusCompleted, // Should NOT be marked as failed
},
{
ID: "sess-already-failed",
Status: store.SessionStatusFailed, // Should NOT be updated
},
}
// Expect ListSessions to be called
mockStore.EXPECT().ListSessions(gomock.Any()).Return(sessions, nil)
// Expect UpdateSession for orphaned sessions only
for _, sess := range sessions {
if sess.Status == store.SessionStatusRunning ||
sess.Status == store.SessionStatusWaitingInput ||
sess.Status == store.SessionStatusStarting {
mockStore.EXPECT().
UpdateSession(gomock.Any(), sess.ID, gomock.Any()).
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
// Verify the update
if update.Status == nil || *update.Status != store.SessionStatusFailed {
t.Errorf("expected status update to failed, got %v", update.Status)
}
if update.ErrorMessage == nil || *update.ErrorMessage != "daemon restarted while session was active" {
t.Errorf("expected error message about daemon restart, got %v", update.ErrorMessage)
}
if update.CompletedAt == nil {
t.Error("expected CompletedAt to be set")
}
return nil
})
}
}
// Create daemon with mock store
d := &Daemon{
store: mockStore,
}
// Call markOrphanedSessionsAsFailed
err := d.markOrphanedSessionsAsFailed(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Expectations are verified by gomock
}
func TestDaemon_MarkOrphanedSessions_NoStore(t *testing.T) {
// Test that it handles nil store gracefully
d := &Daemon{
store: nil,
}
err := d.markOrphanedSessionsAsFailed(context.Background())
if err != nil {
t.Fatalf("expected no error with nil store, got: %v", err)
}
}

View File

@@ -0,0 +1,288 @@
//go:build integration
package daemon
import (
"context"
"fmt"
"testing"
"time"
"github.com/humanlayer/humanlayer/hld/approval"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/client"
"github.com/humanlayer/humanlayer/hld/config"
"github.com/humanlayer/humanlayer/hld/internal/testutil"
"github.com/humanlayer/humanlayer/hld/session"
"github.com/humanlayer/humanlayer/hld/store"
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
)
// TestSessionStateTransitionsIntegration tests the full flow of session state changes
// when approvals are created and resolved
func TestSessionStateTransitionsIntegration(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Create temporary socket path for test
socketPath := testutil.SocketPath(t, "session-state")
// Create in-memory store
testStore, err := store.NewSQLiteStore(":memory:")
if err != nil {
t.Fatalf("failed to create test store: %v", err)
}
defer testStore.Close()
// Create event bus
eventBus := bus.NewEventBus()
// Create mock API client
mockClient := &mockSessionStateAPIClient{
functionCalls: make(map[string]*humanlayer.FunctionCall),
}
// Create real approval components
approvalStore := approval.NewMemoryStore()
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, eventBus)
// Create approval manager
approvalManager := &approval.DefaultManager{
Client: mockClient,
Store: approvalStore,
Poller: poller,
EventBus: eventBus,
ConversationStore: testStore,
}
// Create session manager
sessionManager, err := session.NewManager(eventBus, testStore)
if err != nil {
t.Fatalf("failed to create session manager: %v", err)
}
// Create daemon
d := &Daemon{
config: &config.Config{
SocketPath: socketPath,
APIKey: "test-key",
},
socketPath: socketPath,
sessions: sessionManager,
approvals: approvalManager,
eventBus: eventBus,
store: testStore,
}
// Start daemon
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
doneCh := make(chan error)
go func() {
doneCh <- d.Run(ctx)
}()
// Give daemon time to start
time.Sleep(100 * time.Millisecond)
// Create client
c, err := client.New(socketPath)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
defer c.Close()
// Test scenario: Create session with tool call that needs approval
t.Run("session_state_transitions", func(t *testing.T) {
// 1. Create a session manually in the database
sessionID := "test-session-001"
runID := "test-run-001"
claudeSessionID := "claude-sess-001"
ctx := context.Background()
session := &store.Session{
ID: sessionID,
RunID: runID,
ClaudeSessionID: claudeSessionID,
Query: "Test query",
Status: store.SessionStatusRunning,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
}
if err := testStore.CreateSession(ctx, session); err != nil {
t.Fatalf("failed to create session: %v", err)
}
// 2. Add a tool call that needs approval
toolCall := &store.ConversationEvent{
SessionID: sessionID,
ClaudeSessionID: claudeSessionID,
Sequence: 1,
EventType: store.EventTypeToolCall,
ToolID: "tool-001",
ToolName: "dangerous_function",
ToolInputJSON: `{"action": "delete_all"}`,
CreatedAt: time.Now(),
}
if err := testStore.AddConversationEvent(ctx, toolCall); err != nil {
t.Fatalf("failed to add tool call: %v", err)
}
// 3. Simulate an approval coming from HumanLayer API
approvalID := "approval-001"
mockClient.AddFunctionCall(humanlayer.FunctionCall{
CallID: approvalID,
RunID: runID,
Spec: humanlayer.FunctionCallSpec{
Fn: "dangerous_function",
Kwargs: map[string]interface{}{
"action": "delete_all",
},
},
})
// 4. Wait for poller to pick up the approval and correlate it
time.Sleep(150 * time.Millisecond)
// 5. Check that session status changed to waiting_input
updatedSession, err := testStore.GetSession(ctx, sessionID)
if err != nil {
t.Fatalf("failed to get session: %v", err)
}
if updatedSession.Status != store.SessionStatusWaitingInput {
t.Errorf("expected session status to be waiting_input, got %s", updatedSession.Status)
}
// 6. Check that approval was correlated
conversation, err := testStore.GetConversation(ctx, claudeSessionID)
if err != nil {
t.Fatalf("failed to get conversation: %v", err)
}
var correlatedEvent *store.ConversationEvent
for _, event := range conversation {
if event.EventType == store.EventTypeToolCall && event.ToolName == "dangerous_function" {
correlatedEvent = event
break
}
}
if correlatedEvent == nil {
t.Fatal("tool call event not found")
}
if correlatedEvent.ApprovalStatus != store.ApprovalStatusPending {
t.Errorf("expected approval status to be pending, got %s", correlatedEvent.ApprovalStatus)
}
if correlatedEvent.ApprovalID != approvalID {
t.Errorf("expected approval ID to be %s, got %s", approvalID, correlatedEvent.ApprovalID)
}
// 7. Approve the function call via client
if err := c.SendDecision(approvalID, "function_call", "approve", "Approved for testing"); err != nil {
t.Fatalf("failed to send approval: %v", err)
}
// Give time for status update
time.Sleep(100 * time.Millisecond)
// 8. Check that session status changed back to running
finalSession, err := testStore.GetSession(ctx, sessionID)
if err != nil {
t.Fatalf("failed to get final session: %v", err)
}
if finalSession.Status != store.SessionStatusRunning {
t.Errorf("expected session status to be running after approval, got %s", finalSession.Status)
}
// 9. Check that approval status was updated
finalConversation, err := testStore.GetConversation(ctx, claudeSessionID)
if err != nil {
t.Fatalf("failed to get final conversation: %v", err)
}
var finalEvent *store.ConversationEvent
for _, event := range finalConversation {
if event.EventType == store.EventTypeToolCall && event.ToolName == "dangerous_function" {
finalEvent = event
break
}
}
if finalEvent == nil {
t.Fatal("tool call event not found in final conversation")
}
if finalEvent.ApprovalStatus != store.ApprovalStatusApproved {
t.Errorf("expected approval status to be approved, got %s", finalEvent.ApprovalStatus)
}
})
// Cleanup
cancel()
select {
case err := <-doneCh:
if err != nil {
t.Errorf("daemon returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("daemon did not shut down in time")
}
}
// Mock API client for session state testing
type mockSessionStateAPIClient struct {
functionCalls map[string]*humanlayer.FunctionCall
}
func (m *mockSessionStateAPIClient) AddFunctionCall(fc humanlayer.FunctionCall) {
m.functionCalls[fc.CallID] = &fc
}
func (m *mockSessionStateAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
var result []humanlayer.FunctionCall
for _, fc := range m.functionCalls {
if fc.Status == nil || fc.Status.RespondedAt == nil {
result = append(result, *fc)
}
}
return result, nil
}
func (m *mockSessionStateAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
return []humanlayer.HumanContact{}, nil
}
func (m *mockSessionStateAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
if fc, ok := m.functionCalls[callID]; ok {
if fc.Status == nil {
fc.Status = &humanlayer.FunctionCallStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
fc.Status.RespondedAt = &now
approved := true
fc.Status.Approved = &approved
fc.Status.Comment = comment
return nil
}
return fmt.Errorf("function call not found: %s", callID)
}
func (m *mockSessionStateAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
if fc, ok := m.functionCalls[callID]; ok {
if fc.Status == nil {
fc.Status = &humanlayer.FunctionCallStatus{}
}
now := humanlayer.CustomTime{Time: time.Now()}
fc.Status.RespondedAt = &now
approved := false
fc.Status.Approved = &approved
fc.Status.Comment = reason
return nil
}
return fmt.Errorf("function call not found: %s", callID)
}
func (m *mockSessionStateAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
return fmt.Errorf("not implemented")
}

View File

@@ -20,10 +20,12 @@ type SQLiteStore struct {
// NewSQLiteStore creates a new SQLite-backed store
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
// Ensure directory exists
dbDir := filepath.Dir(dbPath)
if err := os.MkdirAll(dbDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
// Ensure directory exists (skip for in-memory databases)
if dbPath != ":memory:" {
dbDir := filepath.Dir(dbPath)
if err := os.MkdirAll(dbDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// Open database
@@ -308,6 +310,64 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
return &session, nil
}
// GetSessionByRunID retrieves a session by its run_id
func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Session, error) {
query := `
SELECT id, run_id, claude_session_id, parent_session_id,
query, model, working_dir, max_turns, system_prompt, custom_instructions,
status, created_at, last_activity_at, completed_at,
cost_usd, total_tokens, duration_ms, error_message
FROM sessions
WHERE run_id = ?
`
var session Session
var claudeSessionID, parentSessionID, model, workingDir, systemPrompt, customInstructions sql.NullString
var completedAt sql.NullTime
var costUSD sql.NullFloat64
var totalTokens, durationMS sql.NullInt64
var errorMessage sql.NullString
err := s.db.QueryRowContext(ctx, query, runID).Scan(
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
&session.Query, &model, &workingDir, &session.MaxTurns,
&systemPrompt, &customInstructions,
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
&costUSD, &totalTokens, &durationMS, &errorMessage,
)
if err == sql.ErrNoRows {
return nil, nil // No session found
}
if err != nil {
return nil, fmt.Errorf("failed to get session by run_id: %w", err)
}
// Convert nullable fields
session.ClaudeSessionID = claudeSessionID.String
session.ParentSessionID = parentSessionID.String
session.Model = model.String
session.WorkingDir = workingDir.String
session.SystemPrompt = systemPrompt.String
session.CustomInstructions = customInstructions.String
session.ErrorMessage = errorMessage.String
if completedAt.Valid {
session.CompletedAt = &completedAt.Time
}
if costUSD.Valid {
session.CostUSD = &costUSD.Float64
}
if totalTokens.Valid {
tokens := int(totalTokens.Int64)
session.TotalTokens = &tokens
}
if durationMS.Valid {
duration := int(durationMS.Int64)
session.DurationMS = &duration
}
return &session, nil
}
// ListSessions retrieves all sessions
func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
query := `

View File

@@ -13,6 +13,7 @@ type ConversationStore interface {
CreateSession(ctx context.Context, session *Session) error
UpdateSession(ctx context.Context, sessionID string, updates SessionUpdate) error
GetSession(ctx context.Context, sessionID string) (*Session, error)
GetSessionByRunID(ctx context.Context, runID string) (*Session, error)
ListSessions(ctx context.Context) ([]*Session, error)
// Conversation operations