fix resumed sessions to inherit all configuration from parent (#253)

* Add missing configuration fields to sessions table

- Add permission_prompt_tool, append_system_prompt, allowed_tools,
  and disallowed_tools fields to Session struct
- Create database migration to add new columns to sessions table
- Update CreateSession, GetSession, GetSessionByRunID, and ListSessions
  to handle new fields
- Serialize tool arrays as JSON for database storage

* Implement full configuration inheritance for resumed sessions

- Inherit all configuration fields from parent session in ContinueSession
  (except MaxTurns which is intentionally not inherited)
- Retrieve and inherit MCP server configurations from parent
- Add comprehensive tests for field inheritance, MCP inheritance, and
  override behavior
- Update existing tests to handle new mock expectations

* Remove manual inheritance workaround from TUI

Now that the backend properly inherits all configuration from parent
sessions, remove the manual workaround that was storing and applying
parentModel and parentWorkingDir fields.

* Fix database migration ordering and versioning

Move migration application after schema initialization to ensure proper
execution order for both new and existing databases. Update migration
version to 3 and ensure new databases start with the latest schema
version to avoid unnecessary migrations.

* Fix MCP server retrieval ordering for consistent test results

Add ORDER BY clause to GetMCPServers query to ensure deterministic
ordering when retrieving MCP servers from the database. This fixes
the InheritsMCPServers test failure in CI where servers were being
returned in non-deterministic order.

* Fix MCP server ordering for consistent test results

Ensure deterministic ordering when converting MCP servers from map to slice
by sorting server names before iteration. This prevents test flakiness
caused by map iteration order variations.

* Fix integration test failures and improve test infrastructure

- Enable integration tests in standard test suite by default
- Add quiet mode support for integration tests to match unit test output
- Fix scanner usage and connection handling in session tests
- Add config override to prevent loading user configuration in tests
- Filter approval poller events in subscription tests to reduce noise
- Use in-memory database for test isolation

* Fix race conditions and improve error handling

- Fix approval status race condition: only update from pending to resolved,
  never overwrite approved/denied status
- Return error for non-existent sessions instead of empty result
- Fix listener double-close issue by tracking close state
- Handle context cancellation properly in monitorSession to prevent
  'database is closed' errors during shutdown
- Update test expectations to match new error handling behavior

* Fix CI test failure in TestSessionStateTransitionsIntegration

Use temporary database file instead of :memory: to ensure all connections
access the same database. SQLite in-memory databases are unique per connection,
which was causing 'no such table' errors when different components tried to
access the database.

* formatting
This commit is contained in:
Allison Durham
2025-06-27 15:26:50 -07:00
committed by GitHub
parent 1a20585823
commit 37d13c4727
15 changed files with 968 additions and 158 deletions

View File

@@ -66,8 +66,8 @@ test-hlyr: ## Test hlyr CLI tool
@$(MAKE) -C hlyr test VERBOSE=$(VERBOSE)
.PHONY: test-hld
test-hld: ## Test hld daemon (unit tests only)
@$(MAKE) -C hld test-unit VERBOSE=$(VERBOSE)
test-hld: ## Test hld daemon (unit and integration tests)
@$(MAKE) -C hld test VERBOSE=$(VERBOSE)
.PHONY: test-hld-integration
test-hld-integration: ## Test hld daemon (including integration tests)

View File

@@ -5,13 +5,32 @@ build:
go build -o hld ./cmd/hld
# Run all tests
test: test-unit test-integration
test:
@if [ -n "$$VERBOSE" ]; then \
$(MAKE) test-unit test-integration; \
else \
$(MAKE) test-quiet; \
fi
# Run all tests with quiet output
test-quiet:
@. ../hack/run_silent.sh && print_header "hld" "Daemon tests"
@$(MAKE) test-unit-quiet
@$(MAKE) test-integration-quiet
# Base test-unit target overridden below
# Run integration tests (requires build tag)
test-integration: build
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...
@if [ -n "$$VERBOSE" ]; then \
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...; \
else \
$(MAKE) test-integration-quiet; \
fi
# Run integration tests with quiet output
test-integration-quiet: build
@. ../hack/run_silent.sh && run_silent_with_test_count "Integration tests passed" "CGO_LDFLAGS=\"-Wl,-w\" go test -json -tags=integration -run Integration ./daemon/..." "go"
# Run tests with race detection
test-race:

View File

@@ -134,12 +134,17 @@ func (d *Daemon) Run(ctx context.Context) error {
return fmt.Errorf("failed to set socket permissions: %w", err)
}
// Track if listener was already closed
listenerClosed := &struct{ closed bool }{}
// Ensure cleanup on exit
defer func() {
if err := listener.Close(); err != nil {
slog.Warn("failed to close listener", "error", err)
if !listenerClosed.closed {
if err := listener.Close(); err != nil {
slog.Warn("failed to close listener", "error", err)
}
}
if err := os.Remove(d.socketPath); err != nil {
if err := os.Remove(d.socketPath); err != nil && !os.IsNotExist(err) {
slog.Warn("failed to remove socket file", "path", d.socketPath, "error", err)
}
if d.store != nil {
@@ -196,6 +201,7 @@ func (d *Daemon) Run(ctx context.Context) error {
if err := listener.Close(); err != nil {
slog.Warn("error closing listener during shutdown", "error", err)
}
listenerClosed.closed = true
return nil
}

View File

@@ -231,13 +231,17 @@ func TestIntegrationContinueSession(t *testing.T) {
}
}
// Continue the session
// Continue the session with some overrides
req := rpc.ContinueSessionRequest{
SessionID: parentSessionID,
Query: "follow up question",
SystemPrompt: "You are helpful",
CustomInstructions: "Be concise",
MaxTurns: 3,
SessionID: parentSessionID,
Query: "follow up question",
SystemPrompt: "You are helpful",
AppendSystemPrompt: "Always be polite",
CustomInstructions: "Be concise",
PermissionPromptTool: "hlyr",
AllowedTools: []string{"read", "write"},
DisallowedTools: []string{"delete"},
MaxTurns: 3,
}
result, err := sendRPC(t, "continueSession", req)
@@ -284,9 +288,32 @@ func TestIntegrationContinueSession(t *testing.T) {
if newSession.SystemPrompt != "You are helpful" {
t.Errorf("Expected system prompt override, got %s", newSession.SystemPrompt)
}
if newSession.AppendSystemPrompt != "Always be polite" {
t.Errorf("Expected append system prompt, got %s", newSession.AppendSystemPrompt)
}
if newSession.CustomInstructions != "Be concise" {
t.Errorf("Expected custom instructions override, got %s", newSession.CustomInstructions)
}
if newSession.PermissionPromptTool != "hlyr" {
t.Errorf("Expected permission prompt tool, got %s", newSession.PermissionPromptTool)
}
// Check allowed tools
var allowedTools []string
if err := json.Unmarshal([]byte(newSession.AllowedTools), &allowedTools); err == nil {
if len(allowedTools) != 2 || allowedTools[0] != "read" || allowedTools[1] != "write" {
t.Errorf("Expected allowed tools [read, write], got %v", allowedTools)
}
}
// Check disallowed tools
var disallowedTools []string
if err := json.Unmarshal([]byte(newSession.DisallowedTools), &disallowedTools); err == nil {
if len(disallowedTools) != 1 || disallowedTools[0] != "delete" {
t.Errorf("Expected disallowed tools [delete], got %v", disallowedTools)
}
}
if newSession.MaxTurns != 3 {
t.Errorf("Expected max turns 3, got %d", newSession.MaxTurns)
}

View File

@@ -193,7 +193,9 @@ func TestGetConversationIntegration(t *testing.T) {
// because it needs to look up the claude_session_id first
_, err := daemonClient.GetConversation("nonexistent-session")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get conversation")
if err != nil {
assert.Contains(t, err.Error(), "failed to get conversation")
}
})
t.Run("GetSessionState for nonexistent session", func(t *testing.T) {

View File

@@ -4,11 +4,13 @@
package daemon
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"testing"
"time"
@@ -24,6 +26,23 @@ func TestSessionLaunchIntegration(t *testing.T) {
// Set environment for test
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
defer os.Unsetenv("HUMANLAYER_DAEMON_SOCKET")
// Disable API key to prevent approval manager issues in tests
os.Setenv("HUMANLAYER_API_KEY", "")
defer os.Unsetenv("HUMANLAYER_API_KEY")
// Use a temporary config directory to avoid loading user's config
tempDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", tempDir)
defer os.Unsetenv("XDG_CONFIG_HOME")
// Create an empty config file to override any existing config
configDir := filepath.Join(tempDir, "humanlayer")
if err := os.MkdirAll(configDir, 0755); err != nil {
t.Fatalf("Failed to create config dir: %v", err)
}
configFile := filepath.Join(configDir, "humanlayer.json")
if err := os.WriteFile(configFile, []byte(`{}`), 0644); err != nil {
t.Fatalf("Failed to create empty config file: %v", err)
}
// Create and start daemon
daemon, err := New()
@@ -57,15 +76,14 @@ func TestSessionLaunchIntegration(t *testing.T) {
t.Fatalf("Daemon socket not ready after 5 seconds: %v", err)
}
// Connect to daemon
conn, err := net.Dial("unix", socketPath)
if err != nil {
t.Fatalf("Failed to connect to daemon: %v", err)
}
defer conn.Close()
// Test launching a session
t.Run("LaunchSession", func(t *testing.T) {
// Connect to daemon
conn, err := net.Dial("unix", socketPath)
if err != nil {
t.Fatalf("Failed to connect to daemon: %v", err)
}
defer conn.Close()
// Check if claude binary exists
if _, err := os.Stat("/usr/bin/claude"); os.IsNotExist(err) {
t.Skip("Claude binary not found, skipping launch test")
@@ -98,14 +116,17 @@ func TestSessionLaunchIntegration(t *testing.T) {
}
// Read response
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
t.Fatalf("Scanner error: %v", err)
}
t.Fatal("Failed to read response")
}
var resp map[string]interface{}
if err := json.Unmarshal(buf[:n], &resp); err != nil {
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
@@ -140,6 +161,12 @@ func TestSessionLaunchIntegration(t *testing.T) {
// Test listing sessions
t.Run("ListSessions", func(t *testing.T) {
// Connect to daemon
conn, err := net.Dial("unix", socketPath)
if err != nil {
t.Fatalf("Failed to connect to daemon: %v", err)
}
defer conn.Close()
// Send ListSessions request
reqData, _ := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
@@ -153,14 +180,17 @@ func TestSessionLaunchIntegration(t *testing.T) {
}
// Read response
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
t.Fatalf("Scanner error: %v", err)
}
t.Fatal("Failed to read response")
}
var resp map[string]interface{}
if err := json.Unmarshal(buf[:n], &resp); err != nil {
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
@@ -218,6 +248,23 @@ func TestConcurrentSessions(t *testing.T) {
// Set environment for test
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
defer os.Unsetenv("HUMANLAYER_DAEMON_SOCKET")
// Disable API key to prevent approval manager issues in tests
os.Setenv("HUMANLAYER_API_KEY", "")
defer os.Unsetenv("HUMANLAYER_API_KEY")
// Use a temporary config directory to avoid loading user's config
tempDir := t.TempDir()
os.Setenv("XDG_CONFIG_HOME", tempDir)
defer os.Unsetenv("XDG_CONFIG_HOME")
// Create an empty config file to override any existing config
configDir := filepath.Join(tempDir, "humanlayer")
if err := os.MkdirAll(configDir, 0755); err != nil {
t.Fatalf("Failed to create config dir: %v", err)
}
configFile := filepath.Join(configDir, "humanlayer.json")
if err := os.WriteFile(configFile, []byte(`{}`), 0644); err != nil {
t.Fatalf("Failed to create empty config file: %v", err)
}
// Create and start daemon
daemon, err := New()
@@ -280,15 +327,19 @@ func TestConcurrentSessions(t *testing.T) {
}
// Read response
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
results <- fmt.Errorf("session %d: failed to read: %w", sessionNum, err)
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
results <- fmt.Errorf("session %d: scanner error: %w", sessionNum, err)
} else {
results <- fmt.Errorf("session %d: failed to read response", sessionNum)
}
return
}
var resp map[string]interface{}
if err := json.Unmarshal(buf[:n], &resp); err != nil {
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
results <- fmt.Errorf("session %d: failed to parse: %w", sessionNum, err)
return
}
@@ -332,14 +383,17 @@ func TestConcurrentSessions(t *testing.T) {
t.Fatalf("Failed to send list request: %v", err)
}
buf := make([]byte, 8192)
n, err := conn.Read(buf)
if err != nil {
t.Fatalf("Failed to read list response: %v", err)
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
t.Fatalf("Failed to read list response: %v", err)
}
t.Fatal("Failed to read list response: no data")
}
var resp map[string]interface{}
if err := json.Unmarshal(buf[:n], &resp); err != nil {
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
t.Fatalf("Failed to parse list response: %v", err)
}

View File

@@ -5,6 +5,7 @@ package daemon
import (
"context"
"fmt"
"path/filepath"
"testing"
"time"
@@ -28,13 +29,22 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
// Create temporary socket path for test
socketPath := testutil.SocketPath(t, "session-state")
// Create in-memory store
testStore, err := store.NewSQLiteStore(":memory:")
// Use a temporary database file instead of :memory: to ensure all connections
// access the same database (in-memory databases are unique per connection)
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
// Create the store
testStore, err := store.NewSQLiteStore(dbPath)
if err != nil {
t.Fatalf("failed to create test store: %v", err)
}
defer testStore.Close()
// Set environment variables to ensure consistent test behavior
t.Setenv("HUMANLAYER_DATABASE_PATH", dbPath)
t.Setenv("HUMANLAYER_API_KEY", "test-key")
// Create event bus
eventBus := bus.NewEventBus()

View File

@@ -18,6 +18,10 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
socketPath := testutil.CreateTestSocket(t)
t.Setenv("HUMANLAYER_SOCKET_PATH", socketPath)
t.Setenv("HUMANLAYER_LOG_LEVEL", "error")
// Explicitly disable API key to prevent approval polling
t.Setenv("HUMANLAYER_API_KEY", "")
// Use in-memory database for tests
t.Setenv("HUMANLAYER_DATABASE_PATH", ":memory:")
// Create and start daemon
daemon, err := New()
@@ -131,6 +135,14 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
t.Fatalf("Client 2 failed to subscribe: %v", err)
}
// Clear any initial events from approval poller
select {
case <-eventChan1:
// Discard any initial events
case <-time.After(50 * time.Millisecond):
// No initial events, continue
}
// Publish different events
daemon.eventBus.Publish(bus.Event{
Type: bus.EventNewApproval,
@@ -166,6 +178,16 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
select {
case notification, ok := <-eventChan1:
if ok && notification.Event.Type != "" {
// Check if this is an approval poller event (has count/total/type fields)
data := notification.Event.Data
if _, hasCount := data["count"]; hasCount {
if _, hasTotal := data["total"]; hasTotal {
if _, hasType := data["type"]; hasType {
// This is an approval poller event, ignore it
break
}
}
}
t.Errorf("Client 1 unexpectedly received event - Type: %q, Data: %+v", notification.Event.Type, notification.Event.Data)
}
// If channel is closed (!ok) or empty event, that's fine during cleanup

View File

@@ -0,0 +1,471 @@
package session
import (
"context"
"encoding/json"
"testing"
"time"
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
"github.com/humanlayer/humanlayer/hld/bus"
"github.com/humanlayer/humanlayer/hld/store"
)
func TestContinueSessionInheritance(t *testing.T) {
ctx := context.Background()
// Create test components
eventBus := bus.NewEventBus()
sqliteStore, err := store.NewSQLiteStore(":memory:")
if err != nil {
t.Fatalf("Failed to create store: %v", err)
}
defer func() { _ = sqliteStore.Close() }()
manager, err := NewManager(eventBus, sqliteStore)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
t.Run("InheritsAllConfigurationFields", func(t *testing.T) {
// Create parent session with full configuration
parentSessionID := "parent-full-config"
parentSession := &store.Session{
ID: parentSessionID,
RunID: "run-parent",
ClaudeSessionID: "claude-parent",
Status: store.SessionStatusCompleted,
Query: "original query",
Model: "claude-3-opus-20240229",
WorkingDir: "/tmp/test",
MaxTurns: 10,
SystemPrompt: "You are a helpful assistant",
AppendSystemPrompt: "Be concise",
CustomInstructions: "Follow best practices",
PermissionPromptTool: "hlyr",
AllowedTools: `["tool1", "tool2"]`,
DisallowedTools: `["tool3"]`,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
CompletedAt: &time.Time{},
}
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
t.Fatalf("Failed to create parent session: %v", err)
}
// Continue session with minimal config (only required fields)
req := ContinueSessionConfig{
ParentSessionID: parentSessionID,
Query: "follow up question",
}
// This will fail because Claude binary doesn't exist in test env,
// but we can still check that the config was built correctly
_, _ = manager.ContinueSession(ctx, req)
// We expect it to fail at launch, but let's check the session was created with inherited config
sessions, err := sqliteStore.ListSessions(ctx)
if err != nil {
t.Fatalf("Failed to list sessions: %v", err)
}
// Should have parent and child sessions
if len(sessions) < 2 {
t.Fatalf("Expected at least 2 sessions, got %d", len(sessions))
}
// Find the child session (most recent)
var childSession *store.Session
for _, s := range sessions {
if s.ParentSessionID == parentSessionID {
childSession = s
break
}
}
if childSession == nil {
t.Fatal("Child session not found")
}
// Verify all fields were inherited
if childSession.Model != parentSession.Model {
t.Errorf("Model not inherited: got %s, want %s", childSession.Model, parentSession.Model)
}
if childSession.WorkingDir != parentSession.WorkingDir {
t.Errorf("WorkingDir not inherited: got %s, want %s", childSession.WorkingDir, parentSession.WorkingDir)
}
if childSession.SystemPrompt != parentSession.SystemPrompt {
t.Errorf("SystemPrompt not inherited: got %s, want %s", childSession.SystemPrompt, parentSession.SystemPrompt)
}
if childSession.AppendSystemPrompt != parentSession.AppendSystemPrompt {
t.Errorf("AppendSystemPrompt not inherited: got %s, want %s", childSession.AppendSystemPrompt, parentSession.AppendSystemPrompt)
}
if childSession.CustomInstructions != parentSession.CustomInstructions {
t.Errorf("CustomInstructions not inherited: got %s, want %s", childSession.CustomInstructions, parentSession.CustomInstructions)
}
if childSession.PermissionPromptTool != parentSession.PermissionPromptTool {
t.Errorf("PermissionPromptTool not inherited: got %s, want %s", childSession.PermissionPromptTool, parentSession.PermissionPromptTool)
}
// Compare allowed tools (deserialize to compare content, not formatting)
var childAllowed, parentAllowed []string
if err := json.Unmarshal([]byte(childSession.AllowedTools), &childAllowed); err != nil {
t.Fatalf("Failed to unmarshal child allowed tools: %v", err)
}
if err := json.Unmarshal([]byte(parentSession.AllowedTools), &parentAllowed); err != nil {
t.Fatalf("Failed to unmarshal parent allowed tools: %v", err)
}
if len(childAllowed) != len(parentAllowed) {
t.Errorf("AllowedTools length mismatch: got %d, want %d", len(childAllowed), len(parentAllowed))
} else {
for i, tool := range childAllowed {
if tool != parentAllowed[i] {
t.Errorf("AllowedTools[%d] not inherited: got %s, want %s", i, tool, parentAllowed[i])
}
}
}
// Compare disallowed tools
var childDisallowed, parentDisallowed []string
if err := json.Unmarshal([]byte(childSession.DisallowedTools), &childDisallowed); err != nil {
t.Fatalf("Failed to unmarshal child disallowed tools: %v", err)
}
if err := json.Unmarshal([]byte(parentSession.DisallowedTools), &parentDisallowed); err != nil {
t.Fatalf("Failed to unmarshal parent disallowed tools: %v", err)
}
if len(childDisallowed) != len(parentDisallowed) {
t.Errorf("DisallowedTools length mismatch: got %d, want %d", len(childDisallowed), len(parentDisallowed))
} else {
for i, tool := range childDisallowed {
if tool != parentDisallowed[i] {
t.Errorf("DisallowedTools[%d] not inherited: got %s, want %s", i, tool, parentDisallowed[i])
}
}
}
// MaxTurns should NOT be inherited (as per spec)
if childSession.MaxTurns == parentSession.MaxTurns {
t.Error("MaxTurns should not be inherited")
}
})
t.Run("InheritsMCPServers", func(t *testing.T) {
// Create parent session
parentSessionID := "parent-mcp"
parentSession := &store.Session{
ID: parentSessionID,
RunID: "run-mcp",
ClaudeSessionID: "claude-mcp",
Status: store.SessionStatusCompleted,
Query: "mcp query",
Model: "claude-3-opus-20240229",
WorkingDir: "/tmp/test",
SystemPrompt: "Test prompt",
PermissionPromptTool: "hlyr",
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
CompletedAt: &time.Time{},
}
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
t.Fatalf("Failed to create parent session: %v", err)
}
// Store MCP servers for parent
mcpServers := []store.MCPServer{
{
SessionID: parentSessionID,
Name: "test-server-1",
Command: "node",
ArgsJSON: `["server1.js", "--port", "3000"]`,
EnvJSON: `{"NODE_ENV": "test", "API_KEY": "secret"}`,
},
{
SessionID: parentSessionID,
Name: "test-server-2",
Command: "python",
ArgsJSON: `["server2.py"]`,
EnvJSON: `{"PYTHONPATH": "/usr/lib"}`,
},
}
if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, mcpServers); err != nil {
t.Fatalf("Failed to store MCP servers: %v", err)
}
// Continue session
req := ContinueSessionConfig{
ParentSessionID: parentSessionID,
Query: "mcp follow up",
}
_, _ = manager.ContinueSession(ctx, req)
// Expected to fail due to missing Claude binary
// Find the child session
sessions, err := sqliteStore.ListSessions(ctx)
if err != nil {
t.Fatalf("Failed to list sessions: %v", err)
}
var childSession *store.Session
for _, s := range sessions {
if s.ParentSessionID == parentSessionID {
childSession = s
break
}
}
if childSession == nil {
t.Fatal("Child session not found")
}
// Get MCP servers for child session
childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID)
if err != nil {
t.Fatalf("Failed to get child MCP servers: %v", err)
}
// Should have inherited the MCP servers
if len(childMCPServers) != len(mcpServers) {
t.Errorf("MCP servers not inherited: got %d, want %d", len(childMCPServers), len(mcpServers))
}
// Verify server details (accounting for HUMANLAYER_RUN_ID being added)
for i, server := range childMCPServers {
if server.Name != mcpServers[i].Name {
t.Errorf("MCP server %d name mismatch: got %s, want %s", i, server.Name, mcpServers[i].Name)
}
if server.Command != mcpServers[i].Command {
t.Errorf("MCP server %d command mismatch: got %s, want %s", i, server.Command, mcpServers[i].Command)
}
// Compare args (deserialize to compare content)
var childArgs, parentArgs []string
if err := json.Unmarshal([]byte(server.ArgsJSON), &childArgs); err != nil {
t.Fatalf("Failed to unmarshal child args: %v", err)
}
if err := json.Unmarshal([]byte(mcpServers[i].ArgsJSON), &parentArgs); err != nil {
t.Fatalf("Failed to unmarshal parent args: %v", err)
}
if len(childArgs) != len(parentArgs) {
t.Errorf("MCP server %d args length mismatch", i)
} else {
for j, arg := range childArgs {
if arg != parentArgs[j] {
t.Errorf("MCP server %d arg[%d] mismatch: got %s, want %s", i, j, arg, parentArgs[j])
}
}
}
// Compare env (deserialize and check that parent env is subset of child env)
var childEnv, parentEnv map[string]string
if err := json.Unmarshal([]byte(server.EnvJSON), &childEnv); err != nil {
t.Fatalf("Failed to unmarshal child env: %v", err)
}
if err := json.Unmarshal([]byte(mcpServers[i].EnvJSON), &parentEnv); err != nil {
t.Fatalf("Failed to unmarshal parent env: %v", err)
}
// Child should have all parent env vars plus HUMANLAYER_RUN_ID
for key, val := range parentEnv {
if childEnv[key] != val {
t.Errorf("MCP server %d env[%s] mismatch: got %s, want %s", i, key, childEnv[key], val)
}
}
// Should have HUMANLAYER_RUN_ID added
if _, ok := childEnv["HUMANLAYER_RUN_ID"]; !ok {
t.Errorf("MCP server %d missing HUMANLAYER_RUN_ID in env", i)
}
}
})
t.Run("OverridesWorkCorrectly", func(t *testing.T) {
// Create parent session
parentSessionID := "parent-override"
parentSession := &store.Session{
ID: parentSessionID,
RunID: "run-override",
ClaudeSessionID: "claude-override",
Status: store.SessionStatusCompleted,
Query: "original",
Model: "claude-3-opus-20240229",
WorkingDir: "/tmp/test",
SystemPrompt: "Original system prompt",
AppendSystemPrompt: "Original append",
CustomInstructions: "Original instructions",
PermissionPromptTool: "original-tool",
AllowedTools: `["original1", "original2"]`,
DisallowedTools: `["original3"]`,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
CompletedAt: &time.Time{},
}
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
t.Fatalf("Failed to create parent session: %v", err)
}
// Continue with overrides
req := ContinueSessionConfig{
ParentSessionID: parentSessionID,
Query: "override query",
SystemPrompt: "Override system prompt",
AppendSystemPrompt: "Override append",
CustomInstructions: "Override instructions",
PermissionPromptTool: "override-tool",
AllowedTools: []string{"override1", "override2", "override3"},
DisallowedTools: []string{"override4", "override5"},
MaxTurns: 5,
}
_, _ = manager.ContinueSession(ctx, req)
// Expected to fail due to missing Claude binary
// Find the child session
sessions, err := sqliteStore.ListSessions(ctx)
if err != nil {
t.Fatalf("Failed to list sessions: %v", err)
}
var childSession *store.Session
for _, s := range sessions {
if s.ParentSessionID == parentSessionID {
childSession = s
break
}
}
if childSession == nil {
t.Fatal("Child session not found")
}
// Verify overrides were applied
if childSession.SystemPrompt != "Override system prompt" {
t.Errorf("SystemPrompt override failed: got %s", childSession.SystemPrompt)
}
if childSession.AppendSystemPrompt != "Override append" {
t.Errorf("AppendSystemPrompt override failed: got %s", childSession.AppendSystemPrompt)
}
if childSession.CustomInstructions != "Override instructions" {
t.Errorf("CustomInstructions override failed: got %s", childSession.CustomInstructions)
}
if childSession.PermissionPromptTool != "override-tool" {
t.Errorf("PermissionPromptTool override failed: got %s", childSession.PermissionPromptTool)
}
// Check allowed tools
var allowedTools []string
if err := json.Unmarshal([]byte(childSession.AllowedTools), &allowedTools); err != nil {
t.Fatalf("Failed to unmarshal AllowedTools: %v", err)
}
if len(allowedTools) != 3 || allowedTools[0] != "override1" {
t.Errorf("AllowedTools override failed: got %v", allowedTools)
}
// Check disallowed tools
var disallowedTools []string
if err := json.Unmarshal([]byte(childSession.DisallowedTools), &disallowedTools); err != nil {
t.Fatalf("Failed to unmarshal DisallowedTools: %v", err)
}
if len(disallowedTools) != 2 || disallowedTools[0] != "override4" {
t.Errorf("DisallowedTools override failed: got %v", disallowedTools)
}
if childSession.MaxTurns != 5 {
t.Errorf("MaxTurns override failed: got %d", childSession.MaxTurns)
}
})
t.Run("MCPConfigOverride", func(t *testing.T) {
// Create parent session with MCP
parentSessionID := "parent-mcp-override"
parentSession := &store.Session{
ID: parentSessionID,
RunID: "run-mcp-override",
ClaudeSessionID: "claude-mcp-override",
Status: store.SessionStatusCompleted,
Query: "original",
Model: "claude-3-opus-20240229",
WorkingDir: "/tmp/test",
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
CompletedAt: &time.Time{},
}
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
t.Fatalf("Failed to create parent session: %v", err)
}
// Store original MCP servers
originalServers := []store.MCPServer{
{
SessionID: parentSessionID,
Name: "original-server",
Command: "original-cmd",
ArgsJSON: `["original"]`,
EnvJSON: `{"ORIGINAL": "true"}`,
},
}
if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, originalServers); err != nil {
t.Fatalf("Failed to store MCP servers: %v", err)
}
// Continue with MCP override
overrideMCP := &claudecode.MCPConfig{
MCPServers: map[string]claudecode.MCPServer{
"override-server": {
Command: "override-cmd",
Args: []string{"override"},
Env: map[string]string{"OVERRIDE": "true"},
},
},
}
req := ContinueSessionConfig{
ParentSessionID: parentSessionID,
Query: "override query",
MCPConfig: overrideMCP,
}
_, _ = manager.ContinueSession(ctx, req)
// Expected to fail due to missing Claude binary
// Find the child session
sessions, err := sqliteStore.ListSessions(ctx)
if err != nil {
t.Fatalf("Failed to list sessions: %v", err)
}
var childSession *store.Session
for _, s := range sessions {
if s.ParentSessionID == parentSessionID {
childSession = s
break
}
}
if childSession == nil {
t.Fatal("Child session not found")
}
// Get MCP servers for child
childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID)
if err != nil {
t.Fatalf("Failed to get child MCP servers: %v", err)
}
// Should have the override server, not the original
if len(childMCPServers) != 1 {
t.Fatalf("Expected 1 MCP server, got %d", len(childMCPServers))
}
server := childMCPServers[0]
if server.Name != "override-server" {
t.Errorf("MCP server name not overridden: got %s", server.Name)
}
if server.Command != "override-cmd" {
t.Errorf("MCP server command not overridden: got %s", server.Command)
}
})
}

View File

@@ -197,55 +197,84 @@ func (m *Manager) LaunchSession(ctx context.Context, config claudecode.SessionCo
func (m *Manager) monitorSession(ctx context.Context, sessionID, runID string, claudeSession *claudecode.Session, startTime time.Time, config claudecode.SessionConfig) {
// Get the session ID from the Claude session once available
var claudeSessionID string
for event := range claudeSession.Events {
// Store raw event for debugging
eventJSON, err := json.Marshal(event)
if err != nil {
slog.Error("failed to marshal event", "error", err)
} else {
if err := m.store.StoreRawEvent(ctx, sessionID, string(eventJSON)); err != nil {
slog.Debug("failed to store raw event", "error", err)
}
}
// Capture Claude session ID
if event.SessionID != "" && claudeSessionID == "" {
claudeSessionID = event.SessionID
// Note: Claude session ID captured for resume capability
slog.Debug("captured Claude session ID",
"session_id", sessionID,
"claude_session_id", claudeSessionID)
// Update database
update := store.SessionUpdate{
ClaudeSessionID: &claudeSessionID,
}
if err := m.store.UpdateSession(ctx, sessionID, update); err != nil {
slog.Error("failed to update session in database", "error", err)
eventLoop:
for {
select {
case <-ctx.Done():
// Context cancelled, stop processing
slog.Debug("monitorSession context cancelled, stopping event processing",
"session_id", sessionID)
return
case event, ok := <-claudeSession.Events:
if !ok {
// Channel closed, exit loop
break eventLoop
}
// Inject the pending query now that we have Claude session ID
if queryVal, ok := m.pendingQueries.LoadAndDelete(sessionID); ok {
if query, ok := queryVal.(string); ok && query != "" {
if err := m.injectQueryAsFirstEvent(ctx, sessionID, claudeSessionID, query); err != nil {
slog.Error("failed to inject query as first event",
"sessionID", sessionID,
"claudeSessionID", claudeSessionID,
"error", err)
// Check context before each database operation
if ctx.Err() != nil {
slog.Debug("context cancelled during event processing",
"session_id", sessionID)
return
}
// Store raw event for debugging
eventJSON, err := json.Marshal(event)
if err != nil {
slog.Error("failed to marshal event", "error", err)
} else {
if err := m.store.StoreRawEvent(ctx, sessionID, string(eventJSON)); err != nil {
slog.Debug("failed to store raw event", "error", err)
}
}
// Capture Claude session ID
if event.SessionID != "" && claudeSessionID == "" {
claudeSessionID = event.SessionID
// Note: Claude session ID captured for resume capability
slog.Debug("captured Claude session ID",
"session_id", sessionID,
"claude_session_id", claudeSessionID)
// Update database
update := store.SessionUpdate{
ClaudeSessionID: &claudeSessionID,
}
if err := m.store.UpdateSession(ctx, sessionID, update); err != nil {
slog.Error("failed to update session in database", "error", err)
}
// Inject the pending query now that we have Claude session ID
if queryVal, ok := m.pendingQueries.LoadAndDelete(sessionID); ok {
if query, ok := queryVal.(string); ok && query != "" {
if err := m.injectQueryAsFirstEvent(ctx, sessionID, claudeSessionID, query); err != nil {
slog.Error("failed to inject query as first event",
"sessionID", sessionID,
"claudeSessionID", claudeSessionID,
"error", err)
}
}
}
}
}
// Process and store event
if err := m.processStreamEvent(ctx, sessionID, claudeSessionID, event); err != nil {
slog.Error("failed to process stream event", "error", err)
// Process and store event
if err := m.processStreamEvent(ctx, sessionID, claudeSessionID, event); err != nil {
slog.Error("failed to process stream event", "error", err)
}
}
}
// Wait for session to complete
result, err := claudeSession.Wait()
// Check if context was cancelled before updating database
if ctx.Err() != nil {
slog.Debug("context cancelled, skipping final session updates",
"session_id", sessionID)
return
}
endTime := time.Now()
if err != nil {
m.updateSessionStatus(ctx, sessionID, StatusFailed, err.Error())
@@ -724,17 +753,64 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig
}
// Build config for resumed session
// Start with minimal required fields
// Start by inheriting ALL configuration from parent session
config := claudecode.SessionConfig{
Query: req.Query,
SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag
OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON
// Inherit Model and WorkingDir from parent session for database storage
Model: claudecode.Model(parentSession.Model),
WorkingDir: parentSession.WorkingDir,
Query: req.Query,
SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag
OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON
Model: claudecode.Model(parentSession.Model),
WorkingDir: parentSession.WorkingDir,
SystemPrompt: parentSession.SystemPrompt,
AppendSystemPrompt: parentSession.AppendSystemPrompt,
CustomInstructions: parentSession.CustomInstructions,
PermissionPromptTool: parentSession.PermissionPromptTool,
// MaxTurns intentionally NOT inherited - let it default or be specified
}
// Apply optional overrides
// Deserialize JSON arrays for tools
if parentSession.AllowedTools != "" {
var allowedTools []string
if err := json.Unmarshal([]byte(parentSession.AllowedTools), &allowedTools); err == nil {
config.AllowedTools = allowedTools
}
}
if parentSession.DisallowedTools != "" {
var disallowedTools []string
if err := json.Unmarshal([]byte(parentSession.DisallowedTools), &disallowedTools); err == nil {
config.DisallowedTools = disallowedTools
}
}
// Retrieve and inherit MCP configuration from parent session
mcpServers, err := m.store.GetMCPServers(ctx, req.ParentSessionID)
if err == nil && len(mcpServers) > 0 {
config.MCPConfig = &claudecode.MCPConfig{
MCPServers: make(map[string]claudecode.MCPServer),
}
for _, server := range mcpServers {
var args []string
var env map[string]string
if err := json.Unmarshal([]byte(server.ArgsJSON), &args); err != nil {
slog.Warn("failed to unmarshal MCP server args", "error", err, "server", server.Name)
args = []string{}
}
if err := json.Unmarshal([]byte(server.EnvJSON), &env); err != nil {
slog.Warn("failed to unmarshal MCP server env", "error", err, "server", server.Name)
env = map[string]string{}
}
config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{
Command: server.Command,
Args: args,
Env: env,
}
}
slog.Debug("inherited MCP servers from parent session",
"parent_session_id", req.ParentSessionID,
"mcp_server_count", len(mcpServers))
}
// Apply optional overrides (only if explicitly provided)
if req.SystemPrompt != "" {
config.SystemPrompt = req.SystemPrompt
}

View File

@@ -279,6 +279,9 @@ func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) {
}
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
// Expect GetMCPServers call (even if it returns empty)
mockStore.EXPECT().GetMCPServers(gomock.Any(), "parent-1").Return([]store.MCPServer{}, nil)
// Expect session creation with parent reference
mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx interface{}, session *store.Session) error {
@@ -299,6 +302,9 @@ func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) {
return nil
})
// Expect MCP servers to be stored (may or may not be called)
mockStore.EXPECT().StoreMCPServers(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
// Expect status update to running (we can't test the full flow without mocking Claude client)
// May be called twice if Claude fails to launch in background
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
@@ -348,6 +354,9 @@ func TestContinueSession_HandlesOptionalOverrides(t *testing.T) {
}
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
// Expect GetMCPServers call (even if it returns empty)
mockStore.EXPECT().GetMCPServers(gomock.Any(), "parent-1").Return([]store.MCPServer{}, nil)
// Test with various overrides
req := ContinueSessionConfig{
ParentSessionID: "parent-1",
@@ -377,6 +386,9 @@ func TestContinueSession_HandlesOptionalOverrides(t *testing.T) {
return nil
})
// Expect MCP servers to be stored (if MCPConfig override is provided)
mockStore.EXPECT().StoreMCPServers(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
// Expect status update
// May be called twice if Claude fails to launch in background
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
@@ -53,6 +54,12 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
return nil, fmt.Errorf("failed to initialize schema: %w", err)
}
// Apply migrations (this must be called AFTER initSchema for both new and existing databases)
if err := store.applyMigrations(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to apply migrations: %w", err)
}
slog.Info("SQLite store initialized", "path", dbPath)
return store, nil
}
@@ -74,7 +81,11 @@ func (s *SQLiteStore) initSchema() error {
working_dir TEXT,
max_turns INTEGER,
system_prompt TEXT,
append_system_prompt TEXT,
custom_instructions TEXT,
permission_prompt_tool TEXT,
allowed_tools TEXT,
disallowed_tools TEXT,
-- Runtime status
status TEXT NOT NULL DEFAULT 'starting',
@@ -168,14 +179,66 @@ func (s *SQLiteStore) initSchema() error {
return fmt.Errorf("failed to create schema: %w", err)
}
// Record schema version
// Record initial schema version
// For new databases, we start at version 3 since the schema includes all fields
_, err := s.db.Exec(`
INSERT OR IGNORE INTO schema_version (version, description)
VALUES (1, 'Initial schema with conversation events')
`)
if err != nil {
return err
}
// Mark new databases as having all migrations applied
_, err = s.db.Exec(`
INSERT OR IGNORE INTO schema_version (version, description)
VALUES (3, 'Initial schema includes all permission and tool fields')
`)
return err
}
// applyMigrations applies any pending database migrations
func (s *SQLiteStore) applyMigrations() error {
// Get current schema version
var currentVersion int
err := s.db.QueryRow("SELECT MAX(version) FROM schema_version").Scan(&currentVersion)
if err != nil {
return fmt.Errorf("failed to get current schema version: %w", err)
}
// Migration 2: Added constraint to ensure only resumable sessions can be parent sessions
// (This migration already exists in production databases)
// Migration 3: Add missing permission and tool fields
if currentVersion < 3 {
slog.Info("Applying migration 3: Add permission and tool fields")
_, err := s.db.Exec(`
-- Add missing columns to sessions table
ALTER TABLE sessions ADD COLUMN permission_prompt_tool TEXT;
ALTER TABLE sessions ADD COLUMN append_system_prompt TEXT;
ALTER TABLE sessions ADD COLUMN allowed_tools TEXT;
ALTER TABLE sessions ADD COLUMN disallowed_tools TEXT;
`)
if err != nil {
return fmt.Errorf("failed to apply migration 3: %w", err)
}
// Record migration
_, err = s.db.Exec(`
INSERT INTO schema_version (version, description)
VALUES (3, 'Add permission_prompt_tool, append_system_prompt, allowed_tools, disallowed_tools fields')
`)
if err != nil {
return fmt.Errorf("failed to record migration 3: %w", err)
}
slog.Info("Migration 3 applied successfully")
}
return nil
}
// Close closes the database connection
func (s *SQLiteStore) Close() error {
return s.db.Close()
@@ -186,15 +249,17 @@ func (s *SQLiteStore) CreateSession(ctx context.Context, session *Session) error
query := `
INSERT INTO sessions (
id, run_id, claude_session_id, parent_session_id,
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
permission_prompt_tool, allowed_tools, disallowed_tools,
status, created_at, last_activity_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := s.db.ExecContext(ctx, query,
session.ID, session.RunID, session.ClaudeSessionID, session.ParentSessionID,
session.Query, session.Summary, session.Model, session.WorkingDir, session.MaxTurns,
session.SystemPrompt, session.CustomInstructions,
session.SystemPrompt, session.AppendSystemPrompt, session.CustomInstructions,
session.PermissionPromptTool, session.AllowedTools, session.DisallowedTools,
session.Status, session.CreatedAt, session.LastActivityAt,
)
if err != nil {
@@ -284,14 +349,16 @@ func (s *SQLiteStore) UpdateSession(ctx context.Context, sessionID string, updat
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
query := `
SELECT id, run_id, claude_session_id, parent_session_id,
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
permission_prompt_tool, allowed_tools, disallowed_tools,
status, created_at, last_activity_at, completed_at,
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
FROM sessions WHERE id = ?
`
var session Session
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
var completedAt sql.NullTime
var costUSD sql.NullFloat64
var totalTokens, durationMS, numTurns sql.NullInt64
@@ -300,7 +367,8 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
err := s.db.QueryRowContext(ctx, query, sessionID).Scan(
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
&systemPrompt, &customInstructions,
&systemPrompt, &appendSystemPrompt, &customInstructions,
&permissionPromptTool, &allowedTools, &disallowedTools,
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
)
@@ -318,7 +386,11 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
session.Model = model.String
session.WorkingDir = workingDir.String
session.SystemPrompt = systemPrompt.String
session.AppendSystemPrompt = appendSystemPrompt.String
session.CustomInstructions = customInstructions.String
session.PermissionPromptTool = permissionPromptTool.String
session.AllowedTools = allowedTools.String
session.DisallowedTools = disallowedTools.String
session.ResultContent = resultContent.String
session.ErrorMessage = errorMessage.String
if completedAt.Valid {
@@ -347,7 +419,8 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Session, error) {
query := `
SELECT id, run_id, claude_session_id, parent_session_id,
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
permission_prompt_tool, allowed_tools, disallowed_tools,
status, created_at, last_activity_at, completed_at,
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
FROM sessions
@@ -355,7 +428,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
`
var session Session
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
var completedAt sql.NullTime
var costUSD sql.NullFloat64
var totalTokens, durationMS, numTurns sql.NullInt64
@@ -364,7 +438,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
err := s.db.QueryRowContext(ctx, query, runID).Scan(
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
&systemPrompt, &customInstructions,
&systemPrompt, &appendSystemPrompt, &customInstructions,
&permissionPromptTool, &allowedTools, &disallowedTools,
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
)
@@ -382,7 +457,11 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
session.Model = model.String
session.WorkingDir = workingDir.String
session.SystemPrompt = systemPrompt.String
session.AppendSystemPrompt = appendSystemPrompt.String
session.CustomInstructions = customInstructions.String
session.PermissionPromptTool = permissionPromptTool.String
session.AllowedTools = allowedTools.String
session.DisallowedTools = disallowedTools.String
session.ResultContent = resultContent.String
session.ErrorMessage = errorMessage.String
if completedAt.Valid {
@@ -411,7 +490,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
query := `
SELECT id, run_id, claude_session_id, parent_session_id,
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
permission_prompt_tool, allowed_tools, disallowed_tools,
status, created_at, last_activity_at, completed_at,
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
FROM sessions
@@ -427,7 +507,8 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
var sessions []*Session
for rows.Next() {
var session Session
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
var completedAt sql.NullTime
var costUSD sql.NullFloat64
var totalTokens, durationMS, numTurns sql.NullInt64
@@ -436,7 +517,8 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
err := rows.Scan(
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
&systemPrompt, &customInstructions,
&systemPrompt, &appendSystemPrompt, &customInstructions,
&permissionPromptTool, &allowedTools, &disallowedTools,
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
)
@@ -451,7 +533,11 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
session.Model = model.String
session.WorkingDir = workingDir.String
session.SystemPrompt = systemPrompt.String
session.AppendSystemPrompt = appendSystemPrompt.String
session.CustomInstructions = customInstructions.String
session.PermissionPromptTool = permissionPromptTool.String
session.AllowedTools = allowedTools.String
session.DisallowedTools = disallowedTools.String
session.ResultContent = resultContent.String
session.ErrorMessage = errorMessage.String
if completedAt.Valid {
@@ -573,6 +659,7 @@ func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID stri
// Walk up the parent chain to get all related claude session IDs
claudeSessionIDs := []string{}
currentID := sessionID
isFirstSession := true
for currentID != "" {
var claudeSessionID sql.NullString
@@ -584,10 +671,16 @@ func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID stri
).Scan(&claudeSessionID, &parentID)
if err != nil {
if err == sql.ErrNoRows {
break // Session not found, stop walking
// If the requested session doesn't exist, return error
if isFirstSession {
return nil, fmt.Errorf("session not found: %s", sessionID)
}
// Otherwise, parent not found, just stop walking
break
}
return nil, fmt.Errorf("failed to get session: %w", err)
}
isFirstSession = false
// Add claude session ID if present (in reverse order for chronological events)
if claudeSessionID.Valid && claudeSessionID.String != "" {
@@ -892,6 +985,21 @@ func (s *SQLiteStore) CorrelateApprovalByToolID(ctx context.Context, sessionID s
// UpdateApprovalStatus updates the status of an approval
func (s *SQLiteStore) UpdateApprovalStatus(ctx context.Context, approvalID string, status string) error {
// Special handling for resolved status - don't overwrite approved/denied
if status == ApprovalStatusResolved {
query := `
UPDATE conversation_events
SET approval_status = ?
WHERE approval_id = ? AND approval_status = ?
`
_, err := s.db.ExecContext(ctx, query, status, approvalID, ApprovalStatusPending)
if err != nil {
return fmt.Errorf("failed to update approval status: %w", err)
}
return nil
}
// For approved/denied, always update
query := `
UPDATE conversation_events
SET approval_status = ?
@@ -935,6 +1043,7 @@ func (s *SQLiteStore) GetMCPServers(ctx context.Context, sessionID string) ([]MC
SELECT id, session_id, name, command, args_json, env_json
FROM mcp_servers
WHERE session_id = ?
ORDER BY id
`
rows, err := s.db.QueryContext(ctx, query, sessionID)
@@ -975,8 +1084,17 @@ func (s *SQLiteStore) StoreRawEvent(ctx context.Context, sessionID string, event
// Helper function to convert MCP config to store format
func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServer) ([]MCPServer, error) {
// First, collect all server names and sort them for deterministic ordering
names := make([]string, 0, len(config))
for name := range config {
names = append(names, name)
}
// Sort names to ensure consistent ordering
sort.Strings(names)
servers := make([]MCPServer, 0, len(config))
for name, server := range config {
for _, name := range names {
server := config[name]
argsJSON, err := json.Marshal(server.Args)
if err != nil {
return nil, fmt.Errorf("failed to marshal args: %w", err)

View File

@@ -436,10 +436,10 @@ func TestGetSessionConversationWithParentChain(t *testing.T) {
})
t.Run("GetSessionConversation_NonExistentSession", func(t *testing.T) {
// Should return empty events for non-existent session
events, err := store.GetSessionConversation(ctx, "does-not-exist")
require.NoError(t, err)
require.Len(t, events, 0)
// Should return error for non-existent session
_, err := store.GetSessionConversation(ctx, "does-not-exist")
require.Error(t, err)
require.Contains(t, err.Error(), "session not found: does-not-exist")
})
t.Run("GetSessionConversation_NoParent", func(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package store
import (
"context"
"encoding/json"
"time"
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
@@ -42,27 +43,31 @@ type ConversationStore interface {
// Session represents a Claude Code session
type Session struct {
ID string
RunID string
ClaudeSessionID string
ParentSessionID string
Query string
Summary string
Model string
WorkingDir string
MaxTurns int
SystemPrompt string
CustomInstructions string
Status string
CreatedAt time.Time
LastActivityAt time.Time
CompletedAt *time.Time
CostUSD *float64
TotalTokens *int
DurationMS *int
NumTurns *int
ResultContent string
ErrorMessage string
ID string
RunID string
ClaudeSessionID string
ParentSessionID string
Query string
Summary string
Model string
WorkingDir string
MaxTurns int
SystemPrompt string
AppendSystemPrompt string // NEW: Append to system prompt
CustomInstructions string
PermissionPromptTool string // NEW: MCP tool for permission prompts
AllowedTools string // NEW: JSON array of allowed tools
DisallowedTools string // NEW: JSON array of disallowed tools
Status string
CreatedAt time.Time
LastActivityAt time.Time
CompletedAt *time.Time
CostUSD *float64
TotalTokens *int
DurationMS *int
NumTurns *int
ResultContent string
ErrorMessage string
}
// SessionUpdate contains fields that can be updated
@@ -148,17 +153,25 @@ const (
// NewSessionFromConfig creates a Session from Claude SessionConfig
func NewSessionFromConfig(id, runID string, config claudecode.SessionConfig) *Session {
// Convert slices to JSON for storage
allowedToolsJSON, _ := json.Marshal(config.AllowedTools)
disallowedToolsJSON, _ := json.Marshal(config.DisallowedTools)
return &Session{
ID: id,
RunID: runID,
Query: config.Query,
Model: string(config.Model),
WorkingDir: config.WorkingDir,
MaxTurns: config.MaxTurns,
SystemPrompt: config.SystemPrompt,
CustomInstructions: config.CustomInstructions,
Status: SessionStatusStarting,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
ID: id,
RunID: runID,
Query: config.Query,
Model: string(config.Model),
WorkingDir: config.WorkingDir,
MaxTurns: config.MaxTurns,
SystemPrompt: config.SystemPrompt,
AppendSystemPrompt: config.AppendSystemPrompt,
CustomInstructions: config.CustomInstructions,
PermissionPromptTool: config.PermissionPromptTool,
AllowedTools: string(allowedToolsJSON),
DisallowedTools: string(disallowedToolsJSON),
Status: SessionStatusStarting,
CreatedAt: time.Now(),
LastActivityAt: time.Now(),
}
}

View File

@@ -32,10 +32,6 @@ type conversationModel struct {
resumeInput textinput.Model
showResumePrompt bool
// Parent session data for inheritance (stored during resume)
parentModel string
parentWorkingDir string
// Loading states
loading bool
error error
@@ -85,9 +81,6 @@ func (cm *conversationModel) setSession(sessionID string) {
cm.clearApprovalState()
cm.clearResumeState()
cm.stopPolling() // Stop any existing polling
// Clear parent data when switching sessions
cm.parentModel = ""
cm.parentWorkingDir = ""
// Reset scroll tracking
cm.wasAtBottom = true
}
@@ -217,16 +210,6 @@ func (cm *conversationModel) Update(msg tea.Msg, m *model) tea.Cmd {
cm.events = msg.events
cm.lastRefresh = time.Now()
// If this is a child session with missing data, use parent data stored during resume
if cm.session != nil && cm.session.ParentSessionID != "" {
if cm.session.Model == "" && cm.parentModel != "" {
cm.session.Model = cm.parentModel
}
if cm.session.WorkingDir == "" && cm.parentWorkingDir != "" {
cm.session.WorkingDir = cm.parentWorkingDir
}
}
// Cache the conversation for future use (if session and events are not nil)
if cm.session != nil && cm.events != nil {
m.conversationCache.put(cm.sessionID, cm.session, cm.events)
@@ -390,9 +373,6 @@ func (cm *conversationModel) updateResumeInput(msg tea.KeyMsg, m *model) tea.Cmd
if cm.session != nil && cm.sessionID != "" {
query := cm.resumeInput.Value()
if query != "" {
// Store parent session data for inheritance
cm.parentModel = cm.session.Model
cm.parentWorkingDir = cm.session.WorkingDir
return continueSession(m.daemonClient, cm.sessionID, query)
}
}