mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
Correlate Approvals with Tool Calls (#206)
* initial phase 3 implementation * fix(approval): correlate approvals by looking up session ID from run_id - Fix GetPendingToolCall call to use session ID instead of run_id - Add session lookup by run_id before attempting correlation - Improve debug logging for correlation process * fix(store): handle in-memory SQLite databases correctly - Skip directory creation for :memory: databases - Prevents error when using in-memory databases for testing * test: add comprehensive tests for approval correlation - Unit tests for approval correlation logic - Unit tests for session status transitions - Integration test for orphaned session cleanup - Integration test for full approval flow with session state changes * ignore go warnings in commands * switch to getting session by run id (way better)
This commit is contained in:
@@ -9,15 +9,15 @@ test: test-unit test-integration
|
||||
|
||||
# Run unit tests with race detection
|
||||
test-unit:
|
||||
go test -v -race ./...
|
||||
CGO_LDFLAGS="-Wl,-w" go test -v -race ./...
|
||||
|
||||
# Run integration tests (requires build tag)
|
||||
test-integration: build
|
||||
go test -v -tags=integration -run Integration ./daemon/...
|
||||
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...
|
||||
|
||||
# Run tests with race detection
|
||||
test-race:
|
||||
go test -race ./...
|
||||
CGO_LDFLAGS="-Wl,-w" go test -race ./...
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
|
||||
157
hld/approval/correlation_test.go
Normal file
157
hld/approval/correlation_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestPoller_CorrelateApproval(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
functionCall humanlayer.FunctionCall
|
||||
setupStore func(mockStore *store.MockConversationStore)
|
||||
expectCorrelate bool
|
||||
expectStatusUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "successful_correlation",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-123",
|
||||
RunID: "run-456",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "dangerous_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// Expect GetSessionByRunID to find the session
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-456").
|
||||
Return(&store.Session{
|
||||
ID: "sess-789",
|
||||
RunID: "run-456",
|
||||
}, nil)
|
||||
|
||||
// Return a pending tool call
|
||||
mockStore.EXPECT().
|
||||
GetPendingToolCall(gomock.Any(), "sess-789", "dangerous_function").
|
||||
Return(&store.ConversationEvent{
|
||||
ID: 1,
|
||||
SessionID: "sess-789",
|
||||
ToolName: "dangerous_function",
|
||||
}, nil)
|
||||
|
||||
// Expect correlation
|
||||
mockStore.EXPECT().
|
||||
CorrelateApproval(gomock.Any(), "sess-789", "dangerous_function", "fc-123").
|
||||
Return(nil)
|
||||
|
||||
// Expect status update to waiting_input
|
||||
mockStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), "sess-789", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, sessionID string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusWaitingInput {
|
||||
t.Errorf("expected status update to waiting_input, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
},
|
||||
expectCorrelate: true,
|
||||
expectStatusUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "no_matching_session",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-999",
|
||||
RunID: "unknown-run",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "some_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns nil (no matching session)
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "unknown-run").
|
||||
Return(nil, nil)
|
||||
|
||||
// No further calls should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "nil_tool_call_returned",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-nil",
|
||||
RunID: "run-nil",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "nil_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns a matching session
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-nil").
|
||||
Return(&store.Session{
|
||||
ID: "sess-nil",
|
||||
RunID: "run-nil",
|
||||
}, nil)
|
||||
|
||||
// Return nil tool call (no error but no result)
|
||||
mockStore.EXPECT().
|
||||
GetPendingToolCall(gomock.Any(), "sess-nil", "nil_function").
|
||||
Return(nil, nil)
|
||||
|
||||
// No correlation or status update should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "get_session_error",
|
||||
functionCall: humanlayer.FunctionCall{
|
||||
CallID: "fc-error",
|
||||
RunID: "run-error",
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "error_function",
|
||||
},
|
||||
},
|
||||
setupStore: func(mockStore *store.MockConversationStore) {
|
||||
// GetSessionByRunID returns an error
|
||||
mockStore.EXPECT().
|
||||
GetSessionByRunID(gomock.Any(), "run-error").
|
||||
Return(nil, fmt.Errorf("database error"))
|
||||
|
||||
// No further calls should happen
|
||||
},
|
||||
expectCorrelate: false,
|
||||
expectStatusUpdate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock store
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
tt.setupStore(mockStore)
|
||||
|
||||
// Create poller with mock store
|
||||
poller := &Poller{
|
||||
conversationStore: mockStore,
|
||||
}
|
||||
|
||||
// Call correlateApproval
|
||||
ctx := context.Background()
|
||||
poller.correlateApproval(ctx, tt.functionCall)
|
||||
|
||||
// The expectations are verified by gomock
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package approval
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
@@ -20,14 +22,15 @@ type Config struct {
|
||||
|
||||
// DefaultManager is the default implementation of Manager
|
||||
type DefaultManager struct {
|
||||
Client APIClient
|
||||
Store Store
|
||||
Poller *Poller
|
||||
EventBus bus.EventBus
|
||||
Client APIClient
|
||||
Store Store
|
||||
Poller *Poller
|
||||
EventBus bus.EventBus
|
||||
ConversationStore store.ConversationStore
|
||||
}
|
||||
|
||||
// NewManager creates a new approval manager
|
||||
func NewManager(cfg Config, eventBus bus.EventBus) (Manager, error) {
|
||||
func NewManager(cfg Config, eventBus bus.EventBus, conversationStore store.ConversationStore) (Manager, error) {
|
||||
// Set defaults
|
||||
if cfg.PollInterval <= 0 {
|
||||
cfg.PollInterval = 5 * time.Second
|
||||
@@ -58,15 +61,16 @@ func NewManager(cfg Config, eventBus bus.EventBus) (Manager, error) {
|
||||
store := NewMemoryStore()
|
||||
|
||||
// Create poller with configured interval
|
||||
poller := NewPoller(client, store, cfg.PollInterval, eventBus)
|
||||
poller := NewPoller(client, store, conversationStore, cfg.PollInterval, eventBus)
|
||||
poller.maxBackoff = cfg.MaxBackoff
|
||||
poller.backoffFactor = cfg.BackoffFactor
|
||||
|
||||
return &DefaultManager{
|
||||
Client: client,
|
||||
Store: store,
|
||||
Poller: poller,
|
||||
EventBus: eventBus,
|
||||
Client: client,
|
||||
Store: store,
|
||||
Poller: poller,
|
||||
EventBus: eventBus,
|
||||
ConversationStore: conversationStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -82,10 +86,16 @@ func (m *DefaultManager) Stop() {
|
||||
|
||||
// GetPendingApprovals returns all pending approvals, optionally filtered by session
|
||||
func (m *DefaultManager) GetPendingApprovals(sessionID string) ([]PendingApproval, error) {
|
||||
if sessionID != "" {
|
||||
// In the future, we'll need to look up the run_id for this session
|
||||
// For now, return empty since we don't have session->run_id mapping yet
|
||||
return []PendingApproval{}, nil
|
||||
if sessionID != "" && m.ConversationStore != nil {
|
||||
// Look up the session to get its run_id
|
||||
ctx := context.Background()
|
||||
session, err := m.ConversationStore.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
slog.Debug("session not found for approval filter", "session_id", sessionID, "error", err)
|
||||
return []PendingApproval{}, nil
|
||||
}
|
||||
// Get approvals by run_id
|
||||
return m.Store.GetPendingByRunID(session.RunID)
|
||||
}
|
||||
return m.Store.GetAllPending()
|
||||
}
|
||||
@@ -113,9 +123,40 @@ func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string,
|
||||
return fmt.Errorf("failed to update local state: %w", err)
|
||||
}
|
||||
|
||||
// Get the function call to access run_id
|
||||
fc, _ := m.Store.GetFunctionCall(callID)
|
||||
|
||||
// Update approval status in database if we have a conversation store
|
||||
if m.ConversationStore != nil && fc != nil {
|
||||
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusApproved); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
// Don't fail the whole operation for this
|
||||
}
|
||||
|
||||
// Update session status back to running since approval is resolved
|
||||
if fc.RunID != "" {
|
||||
// Look up the session by run_id
|
||||
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
|
||||
runningStatus := store.SessionStatusRunning
|
||||
update := store.SessionUpdate{
|
||||
Status: &runningStatus,
|
||||
}
|
||||
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to running",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
} else {
|
||||
slog.Info("updated session status back to running after approval",
|
||||
"session_id", session.ID,
|
||||
"approval_id", callID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.EventBus != nil {
|
||||
fc, _ := m.Store.GetFunctionCall(callID)
|
||||
if m.EventBus != nil && fc != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
@@ -134,7 +175,7 @@ func (m *DefaultManager) ApproveFunctionCall(ctx context.Context, callID string,
|
||||
// DenyFunctionCall denies a function call
|
||||
func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
|
||||
// First check if we have this function call
|
||||
_, err := m.Store.GetFunctionCall(callID)
|
||||
fc, err := m.Store.GetFunctionCall(callID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("function call not found: %w", err)
|
||||
}
|
||||
@@ -149,9 +190,37 @@ func (m *DefaultManager) DenyFunctionCall(ctx context.Context, callID string, re
|
||||
return fmt.Errorf("failed to update local state: %w", err)
|
||||
}
|
||||
|
||||
// Update approval status in database if we have a conversation store
|
||||
if m.ConversationStore != nil && fc != nil {
|
||||
if err := m.ConversationStore.UpdateApprovalStatus(ctx, callID, store.ApprovalStatusDenied); err != nil {
|
||||
slog.Error("failed to update approval status in database", "error", err)
|
||||
// Don't fail the whole operation for this
|
||||
}
|
||||
|
||||
// Update session status back to running since approval is resolved (denied)
|
||||
if fc.RunID != "" {
|
||||
// Look up the session by run_id
|
||||
session, err := m.ConversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err == nil && session != nil && session.Status == store.SessionStatusWaitingInput {
|
||||
runningStatus := store.SessionStatusRunning
|
||||
update := store.SessionUpdate{
|
||||
Status: &runningStatus,
|
||||
}
|
||||
if err := m.ConversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to running",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
} else {
|
||||
slog.Info("updated session status back to running after denial",
|
||||
"session_id", session.ID,
|
||||
"approval_id", callID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish event
|
||||
if m.EventBus != nil {
|
||||
fc, _ := m.Store.GetFunctionCall(callID)
|
||||
if m.EventBus != nil && fc != nil {
|
||||
m.EventBus.Publish(bus.Event{
|
||||
Type: bus.EventApprovalResolved,
|
||||
Data: map[string]interface{}{
|
||||
|
||||
196
hld/approval/manager_status_test.go
Normal file
196
hld/approval/manager_status_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestManager_SessionStatusTransitions(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupTest func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string // returns callID
|
||||
expectedStatus string
|
||||
}{
|
||||
{
|
||||
name: "approve_updates_session_to_running",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-approve-123"
|
||||
runID := "run-approve-456"
|
||||
sessionID := "sess-approve-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).AnyTimes()
|
||||
|
||||
// API approval succeeds
|
||||
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
|
||||
|
||||
// Get session by run_id
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusWaitingInput,
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// Expect status update to running
|
||||
convStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusRunning {
|
||||
t.Errorf("expected status update to running, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
{
|
||||
name: "deny_updates_session_to_running",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-deny-123"
|
||||
runID := "run-deny-456"
|
||||
sessionID := "sess-deny-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(1)
|
||||
|
||||
// API denial succeeds
|
||||
mockClient.EXPECT().DenyFunctionCall(gomock.Any(), callID, "test reason").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusDenied).Return(nil)
|
||||
|
||||
// Get session by run_id
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusWaitingInput,
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// Expect status update to running
|
||||
convStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), sessionID, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
|
||||
if update.Status == nil || *update.Status != store.SessionStatusRunning {
|
||||
t.Errorf("expected status update to running, got %v", update.Status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
{
|
||||
name: "no_status_update_for_non_waiting_session",
|
||||
setupTest: func(mockClient *MockAPIClient, mockStore *MockStore, convStore *store.MockConversationStore) string {
|
||||
callID := "fc-nochange-123"
|
||||
runID := "run-nochange-456"
|
||||
sessionID := "sess-nochange-789"
|
||||
|
||||
// Function call exists
|
||||
fc := &humanlayer.FunctionCall{
|
||||
CallID: callID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "test_function",
|
||||
},
|
||||
}
|
||||
mockStore.EXPECT().GetFunctionCall(callID).Return(fc, nil).Times(2)
|
||||
|
||||
// API approval succeeds
|
||||
mockClient.EXPECT().ApproveFunctionCall(gomock.Any(), callID, "test comment").Return(nil)
|
||||
|
||||
// Mark as responded
|
||||
mockStore.EXPECT().MarkFunctionCallResponded(callID).Return(nil)
|
||||
|
||||
// Update approval status
|
||||
convStore.EXPECT().UpdateApprovalStatus(gomock.Any(), callID, store.ApprovalStatusApproved).Return(nil)
|
||||
|
||||
// Get session - session is already running, not waiting
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: store.SessionStatusRunning, // Already running
|
||||
}
|
||||
convStore.EXPECT().GetSessionByRunID(gomock.Any(), runID).Return(session, nil)
|
||||
|
||||
// No status update should happen since session is not waiting
|
||||
|
||||
return callID
|
||||
},
|
||||
expectedStatus: store.SessionStatusRunning,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a new controller for each test to isolate expectations
|
||||
testCtrl := gomock.NewController(t)
|
||||
defer testCtrl.Finish()
|
||||
|
||||
// Create mocks
|
||||
mockClient := NewMockAPIClient(testCtrl)
|
||||
mockStore := NewMockStore(testCtrl)
|
||||
convStore := store.NewMockConversationStore(testCtrl)
|
||||
|
||||
// Create manager
|
||||
manager := &DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: mockStore,
|
||||
ConversationStore: convStore,
|
||||
}
|
||||
|
||||
// Setup test expectations and get callID
|
||||
callID := tt.setupTest(mockClient, mockStore, convStore)
|
||||
|
||||
// Execute based on test type
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
if tt.name == "deny_updates_session_to_running" {
|
||||
err = manager.DenyFunctionCall(ctx, callID, "test reason")
|
||||
} else {
|
||||
err = manager.ApproveFunctionCall(ctx, callID, "test comment")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Expectations are verified by gomock
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,34 +8,38 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// Poller polls the HumanLayer API for pending approvals
|
||||
type Poller struct {
|
||||
client APIClient
|
||||
store Store
|
||||
eventBus bus.EventBus
|
||||
interval time.Duration
|
||||
maxBackoff time.Duration
|
||||
backoffFactor float64
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
failureCount int
|
||||
client APIClient
|
||||
store Store
|
||||
conversationStore store.ConversationStore
|
||||
eventBus bus.EventBus
|
||||
interval time.Duration
|
||||
maxBackoff time.Duration
|
||||
backoffFactor float64
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
failureCount int
|
||||
}
|
||||
|
||||
// NewPoller creates a new approval poller
|
||||
func NewPoller(client APIClient, store Store, interval time.Duration, eventBus bus.EventBus) *Poller {
|
||||
func NewPoller(client APIClient, store Store, conversationStore store.ConversationStore, interval time.Duration, eventBus bus.EventBus) *Poller {
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
return &Poller{
|
||||
client: client,
|
||||
store: store,
|
||||
eventBus: eventBus,
|
||||
interval: interval,
|
||||
maxBackoff: 5 * time.Minute,
|
||||
backoffFactor: 2.0,
|
||||
failureCount: 0,
|
||||
client: client,
|
||||
store: store,
|
||||
conversationStore: conversationStore,
|
||||
eventBus: eventBus,
|
||||
interval: interval,
|
||||
maxBackoff: 5 * time.Minute,
|
||||
backoffFactor: 2.0,
|
||||
failureCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,6 +142,11 @@ func (p *Poller) poll(ctx context.Context) {
|
||||
if err := p.store.StoreFunctionCall(fc); err != nil {
|
||||
slog.Error("failed to store function call", "call_id", fc.CallID, "error", err)
|
||||
}
|
||||
|
||||
// Correlate with tool calls in the database
|
||||
if p.conversationStore != nil && fc.RunID != "" {
|
||||
p.correlateApproval(pollCtx, fc)
|
||||
}
|
||||
}
|
||||
slog.Debug("fetched function calls", "count", len(functionCalls), "new", newCount)
|
||||
|
||||
@@ -226,3 +235,62 @@ func (p *Poller) calculateIntervalLocked() time.Duration {
|
||||
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// correlateApproval attempts to match an approval with a pending tool call
|
||||
func (p *Poller) correlateApproval(ctx context.Context, fc humanlayer.FunctionCall) {
|
||||
// Find the session with this run_id
|
||||
session, err := p.conversationStore.GetSessionByRunID(ctx, fc.RunID)
|
||||
if err != nil {
|
||||
slog.Error("failed to get session for correlation",
|
||||
"run_id", fc.RunID,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
// This is expected for approvals that aren't from our Claude sessions
|
||||
slog.Debug("no matching session for approval",
|
||||
"run_id", fc.RunID,
|
||||
"approval_id", fc.CallID)
|
||||
return
|
||||
}
|
||||
|
||||
toolName := fc.Spec.Fn
|
||||
// Try to find a pending tool call for this session and tool
|
||||
toolCall, err := p.conversationStore.GetPendingToolCall(ctx, session.ID, toolName)
|
||||
if err != nil || toolCall == nil {
|
||||
slog.Debug("no matching pending tool call for approval",
|
||||
"session_id", session.ID,
|
||||
"run_id", fc.RunID,
|
||||
"tool_name", toolName,
|
||||
"approval_id", fc.CallID)
|
||||
return
|
||||
}
|
||||
|
||||
// Found a matching tool call in one of our sessions - correlate it
|
||||
if err := p.conversationStore.CorrelateApproval(ctx, session.ID, toolName, fc.CallID); err != nil {
|
||||
slog.Error("failed to correlate approval with tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"session_id", session.ID,
|
||||
"tool_name", toolName,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("correlated approval with tool call",
|
||||
"approval_id", fc.CallID,
|
||||
"session_id", session.ID,
|
||||
"tool_name", toolName,
|
||||
"run_id", fc.RunID)
|
||||
|
||||
// Update session status to waiting_input
|
||||
waitingStatus := store.SessionStatusWaitingInput
|
||||
update := store.SessionUpdate{
|
||||
Status: &waitingStatus,
|
||||
}
|
||||
if err := p.conversationStore.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to update session status to waiting_input",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestPoller_StartStop(t *testing.T) {
|
||||
mockClient.EXPECT().GetPendingFunctionCalls(gomock.Any()).Return([]humanlayer.FunctionCall{}, nil).MinTimes(2)
|
||||
mockClient.EXPECT().GetPendingHumanContacts(gomock.Any()).Return([]humanlayer.HumanContact{}, nil).MinTimes(2)
|
||||
|
||||
poller := NewPoller(mockClient, mockStore, 50*time.Millisecond, nil)
|
||||
poller := NewPoller(mockClient, mockStore, nil, 50*time.Millisecond, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
@@ -94,7 +95,7 @@ func New() (*Daemon, error) {
|
||||
BaseURL: cfg.APIBaseURL,
|
||||
// Use defaults for now, could add to daemon config later
|
||||
}
|
||||
approvalManager, err = approval.NewManager(approvalCfg, eventBus)
|
||||
approvalManager, err = approval.NewManager(approvalCfg, eventBus, conversationStore)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create approval manager: %w", err)
|
||||
}
|
||||
@@ -147,6 +148,12 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
// Create and start RPC server
|
||||
d.rpcServer = rpc.NewServer()
|
||||
|
||||
// Mark orphaned sessions as failed (from previous daemon run)
|
||||
if err := d.markOrphanedSessionsAsFailed(ctx); err != nil {
|
||||
slog.Warn("failed to mark orphaned sessions as failed", "error", err)
|
||||
// Don't fail startup for this
|
||||
}
|
||||
|
||||
// Register subscription handlers
|
||||
subscriptionHandlers := rpc.NewSubscriptionHandlers(d.eventBus)
|
||||
d.rpcServer.SetSubscriptionHandlers(subscriptionHandlers)
|
||||
@@ -221,3 +228,49 @@ func (d *Daemon) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
|
||||
slog.Debug("client disconnected", "remote", conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// markOrphanedSessionsAsFailed marks any sessions that were running or waiting
|
||||
// when the daemon restarted as failed
|
||||
func (d *Daemon) markOrphanedSessionsAsFailed(ctx context.Context) error {
|
||||
if d.store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get all sessions from the database
|
||||
sessions, err := d.store.ListSessions(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list sessions: %w", err)
|
||||
}
|
||||
|
||||
orphanedCount := 0
|
||||
for _, session := range sessions {
|
||||
// Mark running or waiting sessions as failed
|
||||
if session.Status == store.SessionStatusRunning ||
|
||||
session.Status == store.SessionStatusWaitingInput ||
|
||||
session.Status == store.SessionStatusStarting {
|
||||
failedStatus := store.SessionStatusFailed
|
||||
errorMsg := "daemon restarted while session was active"
|
||||
now := time.Now()
|
||||
update := store.SessionUpdate{
|
||||
Status: &failedStatus,
|
||||
CompletedAt: &now,
|
||||
ErrorMessage: &errorMsg,
|
||||
}
|
||||
|
||||
if err := d.store.UpdateSession(ctx, session.ID, update); err != nil {
|
||||
slog.Error("failed to mark orphaned session as failed",
|
||||
"session_id", session.ID,
|
||||
"error", err)
|
||||
// Continue with other sessions
|
||||
} else {
|
||||
orphanedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if orphanedCount > 0 {
|
||||
slog.Info("marked orphaned sessions as failed", "count", orphanedCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -119,15 +119,23 @@ func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Create a minimal in-memory store for testing
|
||||
testStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test store: %v", err)
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Create real approval components for integration testing
|
||||
approvalStore := approval.NewMemoryStore()
|
||||
poller := approval.NewPoller(mockClient, approvalStore, 50*time.Millisecond, nil)
|
||||
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, nil)
|
||||
|
||||
// We need to manually construct the manager with our test client
|
||||
approvalManager := &approval.DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: approvalStore,
|
||||
Poller: poller,
|
||||
Client: mockClient,
|
||||
Store: approvalStore,
|
||||
Poller: poller,
|
||||
ConversationStore: testStore,
|
||||
}
|
||||
|
||||
// Create test daemon with approval manager
|
||||
@@ -141,12 +149,6 @@ func TestDaemonApprovalIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create session manager (we don't need real sessions for this test)
|
||||
// Create a minimal in-memory store for testing
|
||||
testStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test store: %v", err)
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
sessionManager, err := session.NewManager(nil, testStore)
|
||||
if err != nil {
|
||||
|
||||
92
hld/daemon/daemon_orphan_test.go
Normal file
92
hld/daemon/daemon_orphan_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestDaemon_MarkOrphanedSessions(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
|
||||
// Set up sessions with various statuses
|
||||
sessions := []*store.Session{
|
||||
{
|
||||
ID: "sess-running",
|
||||
Status: store.SessionStatusRunning,
|
||||
},
|
||||
{
|
||||
ID: "sess-waiting",
|
||||
Status: store.SessionStatusWaitingInput,
|
||||
},
|
||||
{
|
||||
ID: "sess-starting",
|
||||
Status: store.SessionStatusStarting,
|
||||
},
|
||||
{
|
||||
ID: "sess-completed",
|
||||
Status: store.SessionStatusCompleted, // Should NOT be marked as failed
|
||||
},
|
||||
{
|
||||
ID: "sess-already-failed",
|
||||
Status: store.SessionStatusFailed, // Should NOT be updated
|
||||
},
|
||||
}
|
||||
|
||||
// Expect ListSessions to be called
|
||||
mockStore.EXPECT().ListSessions(gomock.Any()).Return(sessions, nil)
|
||||
|
||||
// Expect UpdateSession for orphaned sessions only
|
||||
for _, sess := range sessions {
|
||||
if sess.Status == store.SessionStatusRunning ||
|
||||
sess.Status == store.SessionStatusWaitingInput ||
|
||||
sess.Status == store.SessionStatusStarting {
|
||||
|
||||
mockStore.EXPECT().
|
||||
UpdateSession(gomock.Any(), sess.ID, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, id string, update store.SessionUpdate) error {
|
||||
// Verify the update
|
||||
if update.Status == nil || *update.Status != store.SessionStatusFailed {
|
||||
t.Errorf("expected status update to failed, got %v", update.Status)
|
||||
}
|
||||
if update.ErrorMessage == nil || *update.ErrorMessage != "daemon restarted while session was active" {
|
||||
t.Errorf("expected error message about daemon restart, got %v", update.ErrorMessage)
|
||||
}
|
||||
if update.CompletedAt == nil {
|
||||
t.Error("expected CompletedAt to be set")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Create daemon with mock store
|
||||
d := &Daemon{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
// Call markOrphanedSessionsAsFailed
|
||||
err := d.markOrphanedSessionsAsFailed(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Expectations are verified by gomock
|
||||
}
|
||||
|
||||
func TestDaemon_MarkOrphanedSessions_NoStore(t *testing.T) {
|
||||
// Test that it handles nil store gracefully
|
||||
d := &Daemon{
|
||||
store: nil,
|
||||
}
|
||||
|
||||
err := d.markOrphanedSessionsAsFailed(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with nil store, got: %v", err)
|
||||
}
|
||||
}
|
||||
288
hld/daemon/daemon_session_state_integration_test.go
Normal file
288
hld/daemon/daemon_session_state_integration_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
//go:build integration
|
||||
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/approval"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/client"
|
||||
"github.com/humanlayer/humanlayer/hld/config"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
humanlayer "github.com/humanlayer/humanlayer/humanlayer-go"
|
||||
)
|
||||
|
||||
// TestSessionStateTransitionsIntegration tests the full flow of session state changes
|
||||
// when approvals are created and resolved
|
||||
func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Create temporary socket path for test
|
||||
socketPath := testutil.SocketPath(t, "session-state")
|
||||
|
||||
// Create in-memory store
|
||||
testStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test store: %v", err)
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Create event bus
|
||||
eventBus := bus.NewEventBus()
|
||||
|
||||
// Create mock API client
|
||||
mockClient := &mockSessionStateAPIClient{
|
||||
functionCalls: make(map[string]*humanlayer.FunctionCall),
|
||||
}
|
||||
|
||||
// Create real approval components
|
||||
approvalStore := approval.NewMemoryStore()
|
||||
poller := approval.NewPoller(mockClient, approvalStore, testStore, 50*time.Millisecond, eventBus)
|
||||
|
||||
// Create approval manager
|
||||
approvalManager := &approval.DefaultManager{
|
||||
Client: mockClient,
|
||||
Store: approvalStore,
|
||||
Poller: poller,
|
||||
EventBus: eventBus,
|
||||
ConversationStore: testStore,
|
||||
}
|
||||
|
||||
// Create session manager
|
||||
sessionManager, err := session.NewManager(eventBus, testStore)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create daemon
|
||||
d := &Daemon{
|
||||
config: &config.Config{
|
||||
SocketPath: socketPath,
|
||||
APIKey: "test-key",
|
||||
},
|
||||
socketPath: socketPath,
|
||||
sessions: sessionManager,
|
||||
approvals: approvalManager,
|
||||
eventBus: eventBus,
|
||||
store: testStore,
|
||||
}
|
||||
|
||||
// Start daemon
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
doneCh := make(chan error)
|
||||
go func() {
|
||||
doneCh <- d.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give daemon time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create client
|
||||
c, err := client.New(socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Test scenario: Create session with tool call that needs approval
|
||||
t.Run("session_state_transitions", func(t *testing.T) {
|
||||
// 1. Create a session manually in the database
|
||||
sessionID := "test-session-001"
|
||||
runID := "test-run-001"
|
||||
claudeSessionID := "claude-sess-001"
|
||||
|
||||
ctx := context.Background()
|
||||
session := &store.Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
Query: "Test query",
|
||||
Status: store.SessionStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
if err := testStore.CreateSession(ctx, session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// 2. Add a tool call that needs approval
|
||||
toolCall := &store.ConversationEvent{
|
||||
SessionID: sessionID,
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
Sequence: 1,
|
||||
EventType: store.EventTypeToolCall,
|
||||
ToolID: "tool-001",
|
||||
ToolName: "dangerous_function",
|
||||
ToolInputJSON: `{"action": "delete_all"}`,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := testStore.AddConversationEvent(ctx, toolCall); err != nil {
|
||||
t.Fatalf("failed to add tool call: %v", err)
|
||||
}
|
||||
|
||||
// 3. Simulate an approval coming from HumanLayer API
|
||||
approvalID := "approval-001"
|
||||
mockClient.AddFunctionCall(humanlayer.FunctionCall{
|
||||
CallID: approvalID,
|
||||
RunID: runID,
|
||||
Spec: humanlayer.FunctionCallSpec{
|
||||
Fn: "dangerous_function",
|
||||
Kwargs: map[string]interface{}{
|
||||
"action": "delete_all",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 4. Wait for poller to pick up the approval and correlate it
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 5. Check that session status changed to waiting_input
|
||||
updatedSession, err := testStore.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get session: %v", err)
|
||||
}
|
||||
if updatedSession.Status != store.SessionStatusWaitingInput {
|
||||
t.Errorf("expected session status to be waiting_input, got %s", updatedSession.Status)
|
||||
}
|
||||
|
||||
// 6. Check that approval was correlated
|
||||
conversation, err := testStore.GetConversation(ctx, claudeSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get conversation: %v", err)
|
||||
}
|
||||
|
||||
var correlatedEvent *store.ConversationEvent
|
||||
for _, event := range conversation {
|
||||
if event.EventType == store.EventTypeToolCall && event.ToolName == "dangerous_function" {
|
||||
correlatedEvent = event
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if correlatedEvent == nil {
|
||||
t.Fatal("tool call event not found")
|
||||
}
|
||||
if correlatedEvent.ApprovalStatus != store.ApprovalStatusPending {
|
||||
t.Errorf("expected approval status to be pending, got %s", correlatedEvent.ApprovalStatus)
|
||||
}
|
||||
if correlatedEvent.ApprovalID != approvalID {
|
||||
t.Errorf("expected approval ID to be %s, got %s", approvalID, correlatedEvent.ApprovalID)
|
||||
}
|
||||
|
||||
// 7. Approve the function call via client
|
||||
if err := c.SendDecision(approvalID, "function_call", "approve", "Approved for testing"); err != nil {
|
||||
t.Fatalf("failed to send approval: %v", err)
|
||||
}
|
||||
|
||||
// Give time for status update
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 8. Check that session status changed back to running
|
||||
finalSession, err := testStore.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get final session: %v", err)
|
||||
}
|
||||
if finalSession.Status != store.SessionStatusRunning {
|
||||
t.Errorf("expected session status to be running after approval, got %s", finalSession.Status)
|
||||
}
|
||||
|
||||
// 9. Check that approval status was updated
|
||||
finalConversation, err := testStore.GetConversation(ctx, claudeSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get final conversation: %v", err)
|
||||
}
|
||||
|
||||
var finalEvent *store.ConversationEvent
|
||||
for _, event := range finalConversation {
|
||||
if event.EventType == store.EventTypeToolCall && event.ToolName == "dangerous_function" {
|
||||
finalEvent = event
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if finalEvent == nil {
|
||||
t.Fatal("tool call event not found in final conversation")
|
||||
}
|
||||
if finalEvent.ApprovalStatus != store.ApprovalStatusApproved {
|
||||
t.Errorf("expected approval status to be approved, got %s", finalEvent.ApprovalStatus)
|
||||
}
|
||||
})
|
||||
|
||||
// Cleanup
|
||||
cancel()
|
||||
select {
|
||||
case err := <-doneCh:
|
||||
if err != nil {
|
||||
t.Errorf("daemon returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("daemon did not shut down in time")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock API client for session state testing
|
||||
type mockSessionStateAPIClient struct {
|
||||
functionCalls map[string]*humanlayer.FunctionCall
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) AddFunctionCall(fc humanlayer.FunctionCall) {
|
||||
m.functionCalls[fc.CallID] = &fc
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) GetPendingFunctionCalls(ctx context.Context) ([]humanlayer.FunctionCall, error) {
|
||||
var result []humanlayer.FunctionCall
|
||||
for _, fc := range m.functionCalls {
|
||||
if fc.Status == nil || fc.Status.RespondedAt == nil {
|
||||
result = append(result, *fc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) GetPendingHumanContacts(ctx context.Context) ([]humanlayer.HumanContact, error) {
|
||||
return []humanlayer.HumanContact{}, nil
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) ApproveFunctionCall(ctx context.Context, callID string, comment string) error {
|
||||
if fc, ok := m.functionCalls[callID]; ok {
|
||||
if fc.Status == nil {
|
||||
fc.Status = &humanlayer.FunctionCallStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
fc.Status.RespondedAt = &now
|
||||
approved := true
|
||||
fc.Status.Approved = &approved
|
||||
fc.Status.Comment = comment
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) DenyFunctionCall(ctx context.Context, callID string, reason string) error {
|
||||
if fc, ok := m.functionCalls[callID]; ok {
|
||||
if fc.Status == nil {
|
||||
fc.Status = &humanlayer.FunctionCallStatus{}
|
||||
}
|
||||
now := humanlayer.CustomTime{Time: time.Now()}
|
||||
fc.Status.RespondedAt = &now
|
||||
approved := false
|
||||
fc.Status.Approved = &approved
|
||||
fc.Status.Comment = reason
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("function call not found: %s", callID)
|
||||
}
|
||||
|
||||
func (m *mockSessionStateAPIClient) RespondToHumanContact(ctx context.Context, callID string, response string) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -20,10 +20,12 @@ type SQLiteStore struct {
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed store
|
||||
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
||||
// Ensure directory exists
|
||||
dbDir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dbDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
// Ensure directory exists (skip for in-memory databases)
|
||||
if dbPath != ":memory:" {
|
||||
dbDir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dbDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Open database
|
||||
@@ -308,6 +310,64 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetSessionByRunID retrieves a session by its run_id
|
||||
func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Session, error) {
|
||||
query := `
|
||||
SELECT id, run_id, claude_session_id, parent_session_id,
|
||||
query, model, working_dir, max_turns, system_prompt, custom_instructions,
|
||||
status, created_at, last_activity_at, completed_at,
|
||||
cost_usd, total_tokens, duration_ms, error_message
|
||||
FROM sessions
|
||||
WHERE run_id = ?
|
||||
`
|
||||
|
||||
var session Session
|
||||
var claudeSessionID, parentSessionID, model, workingDir, systemPrompt, customInstructions sql.NullString
|
||||
var completedAt sql.NullTime
|
||||
var costUSD sql.NullFloat64
|
||||
var totalTokens, durationMS sql.NullInt64
|
||||
var errorMessage sql.NullString
|
||||
|
||||
err := s.db.QueryRowContext(ctx, query, runID).Scan(
|
||||
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
|
||||
&session.Query, &model, &workingDir, &session.MaxTurns,
|
||||
&systemPrompt, &customInstructions,
|
||||
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
|
||||
&costUSD, &totalTokens, &durationMS, &errorMessage,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil // No session found
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by run_id: %w", err)
|
||||
}
|
||||
|
||||
// Convert nullable fields
|
||||
session.ClaudeSessionID = claudeSessionID.String
|
||||
session.ParentSessionID = parentSessionID.String
|
||||
session.Model = model.String
|
||||
session.WorkingDir = workingDir.String
|
||||
session.SystemPrompt = systemPrompt.String
|
||||
session.CustomInstructions = customInstructions.String
|
||||
session.ErrorMessage = errorMessage.String
|
||||
if completedAt.Valid {
|
||||
session.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if costUSD.Valid {
|
||||
session.CostUSD = &costUSD.Float64
|
||||
}
|
||||
if totalTokens.Valid {
|
||||
tokens := int(totalTokens.Int64)
|
||||
session.TotalTokens = &tokens
|
||||
}
|
||||
if durationMS.Valid {
|
||||
duration := int(durationMS.Int64)
|
||||
session.DurationMS = &duration
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ListSessions retrieves all sessions
|
||||
func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
|
||||
query := `
|
||||
|
||||
@@ -13,6 +13,7 @@ type ConversationStore interface {
|
||||
CreateSession(ctx context.Context, session *Session) error
|
||||
UpdateSession(ctx context.Context, sessionID string, updates SessionUpdate) error
|
||||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||
GetSessionByRunID(ctx context.Context, runID string) (*Session, error)
|
||||
ListSessions(ctx context.Context) ([]*Session, error)
|
||||
|
||||
// Conversation operations
|
||||
|
||||
Reference in New Issue
Block a user