mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
fix resumed sessions to inherit all configuration from parent (#253)
* Add missing configuration fields to sessions table - Add permission_prompt_tool, append_system_prompt, allowed_tools, and disallowed_tools fields to Session struct - Create database migration to add new columns to sessions table - Update CreateSession, GetSession, GetSessionByRunID, and ListSessions to handle new fields - Serialize tool arrays as JSON for database storage * Implement full configuration inheritance for resumed sessions - Inherit all configuration fields from parent session in ContinueSession (except MaxTurns which is intentionally not inherited) - Retrieve and inherit MCP server configurations from parent - Add comprehensive tests for field inheritance, MCP inheritance, and override behavior - Update existing tests to handle new mock expectations * Remove manual inheritance workaround from TUI Now that the backend properly inherits all configuration from parent sessions, remove the manual workaround that was storing and applying parentModel and parentWorkingDir fields. * Fix database migration ordering and versioning Move migration application after schema initialization to ensure proper execution order for both new and existing databases. Update migration version to 3 and ensure new databases start with the latest schema version to avoid unnecessary migrations. * Fix MCP server retrieval ordering for consistent test results Add ORDER BY clause to GetMCPServers query to ensure deterministic ordering when retrieving MCP servers from the database. This fixes the InheritsMCPServers test failure in CI where servers were being returned in non-deterministic order. * Fix MCP server ordering for consistent test results Ensure deterministic ordering when converting MCP servers from map to slice by sorting server names before iteration. This prevents test flakiness caused by map iteration order variations. * Fix integration test failures and improve test infrastructure - Enable integration tests in standard test suite by default - Add quiet mode support for integration tests to match unit test output - Fix scanner usage and connection handling in session tests - Add config override to prevent loading user configuration in tests - Filter approval poller events in subscription tests to reduce noise - Use in-memory database for test isolation * Fix race conditions and improve error handling - Fix approval status race condition: only update from pending to resolved, never overwrite approved/denied status - Return error for non-existent sessions instead of empty result - Fix listener double-close issue by tracking close state - Handle context cancellation properly in monitorSession to prevent 'database is closed' errors during shutdown - Update test expectations to match new error handling behavior * Fix CI test failure in TestSessionStateTransitionsIntegration Use temporary database file instead of :memory: to ensure all connections access the same database. SQLite in-memory databases are unique per connection, which was causing 'no such table' errors when different components tried to access the database. * formatting
This commit is contained in:
4
Makefile
4
Makefile
@@ -66,8 +66,8 @@ test-hlyr: ## Test hlyr CLI tool
|
||||
@$(MAKE) -C hlyr test VERBOSE=$(VERBOSE)
|
||||
|
||||
.PHONY: test-hld
|
||||
test-hld: ## Test hld daemon (unit tests only)
|
||||
@$(MAKE) -C hld test-unit VERBOSE=$(VERBOSE)
|
||||
test-hld: ## Test hld daemon (unit and integration tests)
|
||||
@$(MAKE) -C hld test VERBOSE=$(VERBOSE)
|
||||
|
||||
.PHONY: test-hld-integration
|
||||
test-hld-integration: ## Test hld daemon (including integration tests)
|
||||
|
||||
23
hld/Makefile
23
hld/Makefile
@@ -5,13 +5,32 @@ build:
|
||||
go build -o hld ./cmd/hld
|
||||
|
||||
# Run all tests
|
||||
test: test-unit test-integration
|
||||
test:
|
||||
@if [ -n "$$VERBOSE" ]; then \
|
||||
$(MAKE) test-unit test-integration; \
|
||||
else \
|
||||
$(MAKE) test-quiet; \
|
||||
fi
|
||||
|
||||
# Run all tests with quiet output
|
||||
test-quiet:
|
||||
@. ../hack/run_silent.sh && print_header "hld" "Daemon tests"
|
||||
@$(MAKE) test-unit-quiet
|
||||
@$(MAKE) test-integration-quiet
|
||||
|
||||
# Base test-unit target overridden below
|
||||
|
||||
# Run integration tests (requires build tag)
|
||||
test-integration: build
|
||||
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...
|
||||
@if [ -n "$$VERBOSE" ]; then \
|
||||
CGO_LDFLAGS="-Wl,-w" go test -v -tags=integration -run Integration ./daemon/...; \
|
||||
else \
|
||||
$(MAKE) test-integration-quiet; \
|
||||
fi
|
||||
|
||||
# Run integration tests with quiet output
|
||||
test-integration-quiet: build
|
||||
@. ../hack/run_silent.sh && run_silent_with_test_count "Integration tests passed" "CGO_LDFLAGS=\"-Wl,-w\" go test -json -tags=integration -run Integration ./daemon/..." "go"
|
||||
|
||||
# Run tests with race detection
|
||||
test-race:
|
||||
|
||||
@@ -134,12 +134,17 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
// Track if listener was already closed
|
||||
listenerClosed := &struct{ closed bool }{}
|
||||
|
||||
// Ensure cleanup on exit
|
||||
defer func() {
|
||||
if err := listener.Close(); err != nil {
|
||||
slog.Warn("failed to close listener", "error", err)
|
||||
if !listenerClosed.closed {
|
||||
if err := listener.Close(); err != nil {
|
||||
slog.Warn("failed to close listener", "error", err)
|
||||
}
|
||||
}
|
||||
if err := os.Remove(d.socketPath); err != nil {
|
||||
if err := os.Remove(d.socketPath); err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to remove socket file", "path", d.socketPath, "error", err)
|
||||
}
|
||||
if d.store != nil {
|
||||
@@ -196,6 +201,7 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
if err := listener.Close(); err != nil {
|
||||
slog.Warn("error closing listener during shutdown", "error", err)
|
||||
}
|
||||
listenerClosed.closed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -231,13 +231,17 @@ func TestIntegrationContinueSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Continue the session
|
||||
// Continue the session with some overrides
|
||||
req := rpc.ContinueSessionRequest{
|
||||
SessionID: parentSessionID,
|
||||
Query: "follow up question",
|
||||
SystemPrompt: "You are helpful",
|
||||
CustomInstructions: "Be concise",
|
||||
MaxTurns: 3,
|
||||
SessionID: parentSessionID,
|
||||
Query: "follow up question",
|
||||
SystemPrompt: "You are helpful",
|
||||
AppendSystemPrompt: "Always be polite",
|
||||
CustomInstructions: "Be concise",
|
||||
PermissionPromptTool: "hlyr",
|
||||
AllowedTools: []string{"read", "write"},
|
||||
DisallowedTools: []string{"delete"},
|
||||
MaxTurns: 3,
|
||||
}
|
||||
|
||||
result, err := sendRPC(t, "continueSession", req)
|
||||
@@ -284,9 +288,32 @@ func TestIntegrationContinueSession(t *testing.T) {
|
||||
if newSession.SystemPrompt != "You are helpful" {
|
||||
t.Errorf("Expected system prompt override, got %s", newSession.SystemPrompt)
|
||||
}
|
||||
if newSession.AppendSystemPrompt != "Always be polite" {
|
||||
t.Errorf("Expected append system prompt, got %s", newSession.AppendSystemPrompt)
|
||||
}
|
||||
if newSession.CustomInstructions != "Be concise" {
|
||||
t.Errorf("Expected custom instructions override, got %s", newSession.CustomInstructions)
|
||||
}
|
||||
if newSession.PermissionPromptTool != "hlyr" {
|
||||
t.Errorf("Expected permission prompt tool, got %s", newSession.PermissionPromptTool)
|
||||
}
|
||||
|
||||
// Check allowed tools
|
||||
var allowedTools []string
|
||||
if err := json.Unmarshal([]byte(newSession.AllowedTools), &allowedTools); err == nil {
|
||||
if len(allowedTools) != 2 || allowedTools[0] != "read" || allowedTools[1] != "write" {
|
||||
t.Errorf("Expected allowed tools [read, write], got %v", allowedTools)
|
||||
}
|
||||
}
|
||||
|
||||
// Check disallowed tools
|
||||
var disallowedTools []string
|
||||
if err := json.Unmarshal([]byte(newSession.DisallowedTools), &disallowedTools); err == nil {
|
||||
if len(disallowedTools) != 1 || disallowedTools[0] != "delete" {
|
||||
t.Errorf("Expected disallowed tools [delete], got %v", disallowedTools)
|
||||
}
|
||||
}
|
||||
|
||||
if newSession.MaxTurns != 3 {
|
||||
t.Errorf("Expected max turns 3, got %d", newSession.MaxTurns)
|
||||
}
|
||||
|
||||
@@ -193,7 +193,9 @@ func TestGetConversationIntegration(t *testing.T) {
|
||||
// because it needs to look up the claude_session_id first
|
||||
_, err := daemonClient.GetConversation("nonexistent-session")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get conversation")
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "failed to get conversation")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSessionState for nonexistent session", func(t *testing.T) {
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -24,6 +26,23 @@ func TestSessionLaunchIntegration(t *testing.T) {
|
||||
// Set environment for test
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
defer os.Unsetenv("HUMANLAYER_DAEMON_SOCKET")
|
||||
// Disable API key to prevent approval manager issues in tests
|
||||
os.Setenv("HUMANLAYER_API_KEY", "")
|
||||
defer os.Unsetenv("HUMANLAYER_API_KEY")
|
||||
// Use a temporary config directory to avoid loading user's config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
defer os.Unsetenv("XDG_CONFIG_HOME")
|
||||
|
||||
// Create an empty config file to override any existing config
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create config dir: %v", err)
|
||||
}
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
if err := os.WriteFile(configFile, []byte(`{}`), 0644); err != nil {
|
||||
t.Fatalf("Failed to create empty config file: %v", err)
|
||||
}
|
||||
|
||||
// Create and start daemon
|
||||
daemon, err := New()
|
||||
@@ -57,15 +76,14 @@ func TestSessionLaunchIntegration(t *testing.T) {
|
||||
t.Fatalf("Daemon socket not ready after 5 seconds: %v", err)
|
||||
}
|
||||
|
||||
// Connect to daemon
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to daemon: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Test launching a session
|
||||
t.Run("LaunchSession", func(t *testing.T) {
|
||||
// Connect to daemon
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to daemon: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
// Check if claude binary exists
|
||||
if _, err := os.Stat("/usr/bin/claude"); os.IsNotExist(err) {
|
||||
t.Skip("Claude binary not found, skipping launch test")
|
||||
@@ -98,14 +116,17 @@ func TestSessionLaunchIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Read response
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("Scanner error: %v", err)
|
||||
}
|
||||
t.Fatal("Failed to read response")
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(buf[:n], &resp); err != nil {
|
||||
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
@@ -140,6 +161,12 @@ func TestSessionLaunchIntegration(t *testing.T) {
|
||||
|
||||
// Test listing sessions
|
||||
t.Run("ListSessions", func(t *testing.T) {
|
||||
// Connect to daemon
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to daemon: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
// Send ListSessions request
|
||||
reqData, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
@@ -153,14 +180,17 @@ func TestSessionLaunchIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Read response
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("Scanner error: %v", err)
|
||||
}
|
||||
t.Fatal("Failed to read response")
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(buf[:n], &resp); err != nil {
|
||||
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
@@ -218,6 +248,23 @@ func TestConcurrentSessions(t *testing.T) {
|
||||
// Set environment for test
|
||||
os.Setenv("HUMANLAYER_DAEMON_SOCKET", socketPath)
|
||||
defer os.Unsetenv("HUMANLAYER_DAEMON_SOCKET")
|
||||
// Disable API key to prevent approval manager issues in tests
|
||||
os.Setenv("HUMANLAYER_API_KEY", "")
|
||||
defer os.Unsetenv("HUMANLAYER_API_KEY")
|
||||
// Use a temporary config directory to avoid loading user's config
|
||||
tempDir := t.TempDir()
|
||||
os.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
defer os.Unsetenv("XDG_CONFIG_HOME")
|
||||
|
||||
// Create an empty config file to override any existing config
|
||||
configDir := filepath.Join(tempDir, "humanlayer")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create config dir: %v", err)
|
||||
}
|
||||
configFile := filepath.Join(configDir, "humanlayer.json")
|
||||
if err := os.WriteFile(configFile, []byte(`{}`), 0644); err != nil {
|
||||
t.Fatalf("Failed to create empty config file: %v", err)
|
||||
}
|
||||
|
||||
// Create and start daemon
|
||||
daemon, err := New()
|
||||
@@ -280,15 +327,19 @@ func TestConcurrentSessions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Read response
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("session %d: failed to read: %w", sessionNum, err)
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
results <- fmt.Errorf("session %d: scanner error: %w", sessionNum, err)
|
||||
} else {
|
||||
results <- fmt.Errorf("session %d: failed to read response", sessionNum)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(buf[:n], &resp); err != nil {
|
||||
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
|
||||
results <- fmt.Errorf("session %d: failed to parse: %w", sessionNum, err)
|
||||
return
|
||||
}
|
||||
@@ -332,14 +383,17 @@ func TestConcurrentSessions(t *testing.T) {
|
||||
t.Fatalf("Failed to send list request: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 8192)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read list response: %v", err)
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // 1MB buffer for large responses
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("Failed to read list response: %v", err)
|
||||
}
|
||||
t.Fatal("Failed to read list response: no data")
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(buf[:n], &resp); err != nil {
|
||||
if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse list response: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ package daemon
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -28,13 +29,22 @@ func TestSessionStateTransitionsIntegration(t *testing.T) {
|
||||
// Create temporary socket path for test
|
||||
socketPath := testutil.SocketPath(t, "session-state")
|
||||
|
||||
// Create in-memory store
|
||||
testStore, err := store.NewSQLiteStore(":memory:")
|
||||
// Use a temporary database file instead of :memory: to ensure all connections
|
||||
// access the same database (in-memory databases are unique per connection)
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "test.db")
|
||||
|
||||
// Create the store
|
||||
testStore, err := store.NewSQLiteStore(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test store: %v", err)
|
||||
}
|
||||
defer testStore.Close()
|
||||
|
||||
// Set environment variables to ensure consistent test behavior
|
||||
t.Setenv("HUMANLAYER_DATABASE_PATH", dbPath)
|
||||
t.Setenv("HUMANLAYER_API_KEY", "test-key")
|
||||
|
||||
// Create event bus
|
||||
eventBus := bus.NewEventBus()
|
||||
|
||||
|
||||
@@ -18,6 +18,10 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
|
||||
socketPath := testutil.CreateTestSocket(t)
|
||||
t.Setenv("HUMANLAYER_SOCKET_PATH", socketPath)
|
||||
t.Setenv("HUMANLAYER_LOG_LEVEL", "error")
|
||||
// Explicitly disable API key to prevent approval polling
|
||||
t.Setenv("HUMANLAYER_API_KEY", "")
|
||||
// Use in-memory database for tests
|
||||
t.Setenv("HUMANLAYER_DATABASE_PATH", ":memory:")
|
||||
|
||||
// Create and start daemon
|
||||
daemon, err := New()
|
||||
@@ -131,6 +135,14 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
|
||||
t.Fatalf("Client 2 failed to subscribe: %v", err)
|
||||
}
|
||||
|
||||
// Clear any initial events from approval poller
|
||||
select {
|
||||
case <-eventChan1:
|
||||
// Discard any initial events
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// No initial events, continue
|
||||
}
|
||||
|
||||
// Publish different events
|
||||
daemon.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventNewApproval,
|
||||
@@ -166,6 +178,16 @@ func TestDaemonSubscriptionIntegration(t *testing.T) {
|
||||
select {
|
||||
case notification, ok := <-eventChan1:
|
||||
if ok && notification.Event.Type != "" {
|
||||
// Check if this is an approval poller event (has count/total/type fields)
|
||||
data := notification.Event.Data
|
||||
if _, hasCount := data["count"]; hasCount {
|
||||
if _, hasTotal := data["total"]; hasTotal {
|
||||
if _, hasType := data["type"]; hasType {
|
||||
// This is an approval poller event, ignore it
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("Client 1 unexpectedly received event - Type: %q, Data: %+v", notification.Event.Type, notification.Event.Data)
|
||||
}
|
||||
// If channel is closed (!ok) or empty event, that's fine during cleanup
|
||||
|
||||
471
hld/session/continue_inheritance_test.go
Normal file
471
hld/session/continue_inheritance_test.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
func TestContinueSessionInheritance(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test components
|
||||
eventBus := bus.NewEventBus()
|
||||
sqliteStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create store: %v", err)
|
||||
}
|
||||
defer func() { _ = sqliteStore.Close() }()
|
||||
|
||||
manager, err := NewManager(eventBus, sqliteStore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
t.Run("InheritsAllConfigurationFields", func(t *testing.T) {
|
||||
// Create parent session with full configuration
|
||||
parentSessionID := "parent-full-config"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-parent",
|
||||
ClaudeSessionID: "claude-parent",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
Model: "claude-3-opus-20240229",
|
||||
WorkingDir: "/tmp/test",
|
||||
MaxTurns: 10,
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
AppendSystemPrompt: "Be concise",
|
||||
CustomInstructions: "Follow best practices",
|
||||
PermissionPromptTool: "hlyr",
|
||||
AllowedTools: `["tool1", "tool2"]`,
|
||||
DisallowedTools: `["tool3"]`,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Continue session with minimal config (only required fields)
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: parentSessionID,
|
||||
Query: "follow up question",
|
||||
}
|
||||
|
||||
// This will fail because Claude binary doesn't exist in test env,
|
||||
// but we can still check that the config was built correctly
|
||||
_, _ = manager.ContinueSession(ctx, req)
|
||||
|
||||
// We expect it to fail at launch, but let's check the session was created with inherited config
|
||||
sessions, err := sqliteStore.ListSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
// Should have parent and child sessions
|
||||
if len(sessions) < 2 {
|
||||
t.Fatalf("Expected at least 2 sessions, got %d", len(sessions))
|
||||
}
|
||||
|
||||
// Find the child session (most recent)
|
||||
var childSession *store.Session
|
||||
for _, s := range sessions {
|
||||
if s.ParentSessionID == parentSessionID {
|
||||
childSession = s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if childSession == nil {
|
||||
t.Fatal("Child session not found")
|
||||
}
|
||||
|
||||
// Verify all fields were inherited
|
||||
if childSession.Model != parentSession.Model {
|
||||
t.Errorf("Model not inherited: got %s, want %s", childSession.Model, parentSession.Model)
|
||||
}
|
||||
if childSession.WorkingDir != parentSession.WorkingDir {
|
||||
t.Errorf("WorkingDir not inherited: got %s, want %s", childSession.WorkingDir, parentSession.WorkingDir)
|
||||
}
|
||||
if childSession.SystemPrompt != parentSession.SystemPrompt {
|
||||
t.Errorf("SystemPrompt not inherited: got %s, want %s", childSession.SystemPrompt, parentSession.SystemPrompt)
|
||||
}
|
||||
if childSession.AppendSystemPrompt != parentSession.AppendSystemPrompt {
|
||||
t.Errorf("AppendSystemPrompt not inherited: got %s, want %s", childSession.AppendSystemPrompt, parentSession.AppendSystemPrompt)
|
||||
}
|
||||
if childSession.CustomInstructions != parentSession.CustomInstructions {
|
||||
t.Errorf("CustomInstructions not inherited: got %s, want %s", childSession.CustomInstructions, parentSession.CustomInstructions)
|
||||
}
|
||||
if childSession.PermissionPromptTool != parentSession.PermissionPromptTool {
|
||||
t.Errorf("PermissionPromptTool not inherited: got %s, want %s", childSession.PermissionPromptTool, parentSession.PermissionPromptTool)
|
||||
}
|
||||
// Compare allowed tools (deserialize to compare content, not formatting)
|
||||
var childAllowed, parentAllowed []string
|
||||
if err := json.Unmarshal([]byte(childSession.AllowedTools), &childAllowed); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child allowed tools: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(parentSession.AllowedTools), &parentAllowed); err != nil {
|
||||
t.Fatalf("Failed to unmarshal parent allowed tools: %v", err)
|
||||
}
|
||||
if len(childAllowed) != len(parentAllowed) {
|
||||
t.Errorf("AllowedTools length mismatch: got %d, want %d", len(childAllowed), len(parentAllowed))
|
||||
} else {
|
||||
for i, tool := range childAllowed {
|
||||
if tool != parentAllowed[i] {
|
||||
t.Errorf("AllowedTools[%d] not inherited: got %s, want %s", i, tool, parentAllowed[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compare disallowed tools
|
||||
var childDisallowed, parentDisallowed []string
|
||||
if err := json.Unmarshal([]byte(childSession.DisallowedTools), &childDisallowed); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child disallowed tools: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(parentSession.DisallowedTools), &parentDisallowed); err != nil {
|
||||
t.Fatalf("Failed to unmarshal parent disallowed tools: %v", err)
|
||||
}
|
||||
if len(childDisallowed) != len(parentDisallowed) {
|
||||
t.Errorf("DisallowedTools length mismatch: got %d, want %d", len(childDisallowed), len(parentDisallowed))
|
||||
} else {
|
||||
for i, tool := range childDisallowed {
|
||||
if tool != parentDisallowed[i] {
|
||||
t.Errorf("DisallowedTools[%d] not inherited: got %s, want %s", i, tool, parentDisallowed[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MaxTurns should NOT be inherited (as per spec)
|
||||
if childSession.MaxTurns == parentSession.MaxTurns {
|
||||
t.Error("MaxTurns should not be inherited")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InheritsMCPServers", func(t *testing.T) {
|
||||
// Create parent session
|
||||
parentSessionID := "parent-mcp"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-mcp",
|
||||
ClaudeSessionID: "claude-mcp",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "mcp query",
|
||||
Model: "claude-3-opus-20240229",
|
||||
WorkingDir: "/tmp/test",
|
||||
SystemPrompt: "Test prompt",
|
||||
PermissionPromptTool: "hlyr",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Store MCP servers for parent
|
||||
mcpServers := []store.MCPServer{
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
Name: "test-server-1",
|
||||
Command: "node",
|
||||
ArgsJSON: `["server1.js", "--port", "3000"]`,
|
||||
EnvJSON: `{"NODE_ENV": "test", "API_KEY": "secret"}`,
|
||||
},
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
Name: "test-server-2",
|
||||
Command: "python",
|
||||
ArgsJSON: `["server2.py"]`,
|
||||
EnvJSON: `{"PYTHONPATH": "/usr/lib"}`,
|
||||
},
|
||||
}
|
||||
if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, mcpServers); err != nil {
|
||||
t.Fatalf("Failed to store MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Continue session
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: parentSessionID,
|
||||
Query: "mcp follow up",
|
||||
}
|
||||
|
||||
_, _ = manager.ContinueSession(ctx, req)
|
||||
// Expected to fail due to missing Claude binary
|
||||
|
||||
// Find the child session
|
||||
sessions, err := sqliteStore.ListSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
var childSession *store.Session
|
||||
for _, s := range sessions {
|
||||
if s.ParentSessionID == parentSessionID {
|
||||
childSession = s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if childSession == nil {
|
||||
t.Fatal("Child session not found")
|
||||
}
|
||||
|
||||
// Get MCP servers for child session
|
||||
childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get child MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Should have inherited the MCP servers
|
||||
if len(childMCPServers) != len(mcpServers) {
|
||||
t.Errorf("MCP servers not inherited: got %d, want %d", len(childMCPServers), len(mcpServers))
|
||||
}
|
||||
|
||||
// Verify server details (accounting for HUMANLAYER_RUN_ID being added)
|
||||
for i, server := range childMCPServers {
|
||||
if server.Name != mcpServers[i].Name {
|
||||
t.Errorf("MCP server %d name mismatch: got %s, want %s", i, server.Name, mcpServers[i].Name)
|
||||
}
|
||||
if server.Command != mcpServers[i].Command {
|
||||
t.Errorf("MCP server %d command mismatch: got %s, want %s", i, server.Command, mcpServers[i].Command)
|
||||
}
|
||||
|
||||
// Compare args (deserialize to compare content)
|
||||
var childArgs, parentArgs []string
|
||||
if err := json.Unmarshal([]byte(server.ArgsJSON), &childArgs); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child args: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(mcpServers[i].ArgsJSON), &parentArgs); err != nil {
|
||||
t.Fatalf("Failed to unmarshal parent args: %v", err)
|
||||
}
|
||||
if len(childArgs) != len(parentArgs) {
|
||||
t.Errorf("MCP server %d args length mismatch", i)
|
||||
} else {
|
||||
for j, arg := range childArgs {
|
||||
if arg != parentArgs[j] {
|
||||
t.Errorf("MCP server %d arg[%d] mismatch: got %s, want %s", i, j, arg, parentArgs[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compare env (deserialize and check that parent env is subset of child env)
|
||||
var childEnv, parentEnv map[string]string
|
||||
if err := json.Unmarshal([]byte(server.EnvJSON), &childEnv); err != nil {
|
||||
t.Fatalf("Failed to unmarshal child env: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(mcpServers[i].EnvJSON), &parentEnv); err != nil {
|
||||
t.Fatalf("Failed to unmarshal parent env: %v", err)
|
||||
}
|
||||
|
||||
// Child should have all parent env vars plus HUMANLAYER_RUN_ID
|
||||
for key, val := range parentEnv {
|
||||
if childEnv[key] != val {
|
||||
t.Errorf("MCP server %d env[%s] mismatch: got %s, want %s", i, key, childEnv[key], val)
|
||||
}
|
||||
}
|
||||
|
||||
// Should have HUMANLAYER_RUN_ID added
|
||||
if _, ok := childEnv["HUMANLAYER_RUN_ID"]; !ok {
|
||||
t.Errorf("MCP server %d missing HUMANLAYER_RUN_ID in env", i)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OverridesWorkCorrectly", func(t *testing.T) {
|
||||
// Create parent session
|
||||
parentSessionID := "parent-override"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-override",
|
||||
ClaudeSessionID: "claude-override",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original",
|
||||
Model: "claude-3-opus-20240229",
|
||||
WorkingDir: "/tmp/test",
|
||||
SystemPrompt: "Original system prompt",
|
||||
AppendSystemPrompt: "Original append",
|
||||
CustomInstructions: "Original instructions",
|
||||
PermissionPromptTool: "original-tool",
|
||||
AllowedTools: `["original1", "original2"]`,
|
||||
DisallowedTools: `["original3"]`,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Continue with overrides
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: parentSessionID,
|
||||
Query: "override query",
|
||||
SystemPrompt: "Override system prompt",
|
||||
AppendSystemPrompt: "Override append",
|
||||
CustomInstructions: "Override instructions",
|
||||
PermissionPromptTool: "override-tool",
|
||||
AllowedTools: []string{"override1", "override2", "override3"},
|
||||
DisallowedTools: []string{"override4", "override5"},
|
||||
MaxTurns: 5,
|
||||
}
|
||||
|
||||
_, _ = manager.ContinueSession(ctx, req)
|
||||
// Expected to fail due to missing Claude binary
|
||||
|
||||
// Find the child session
|
||||
sessions, err := sqliteStore.ListSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
var childSession *store.Session
|
||||
for _, s := range sessions {
|
||||
if s.ParentSessionID == parentSessionID {
|
||||
childSession = s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if childSession == nil {
|
||||
t.Fatal("Child session not found")
|
||||
}
|
||||
|
||||
// Verify overrides were applied
|
||||
if childSession.SystemPrompt != "Override system prompt" {
|
||||
t.Errorf("SystemPrompt override failed: got %s", childSession.SystemPrompt)
|
||||
}
|
||||
if childSession.AppendSystemPrompt != "Override append" {
|
||||
t.Errorf("AppendSystemPrompt override failed: got %s", childSession.AppendSystemPrompt)
|
||||
}
|
||||
if childSession.CustomInstructions != "Override instructions" {
|
||||
t.Errorf("CustomInstructions override failed: got %s", childSession.CustomInstructions)
|
||||
}
|
||||
if childSession.PermissionPromptTool != "override-tool" {
|
||||
t.Errorf("PermissionPromptTool override failed: got %s", childSession.PermissionPromptTool)
|
||||
}
|
||||
|
||||
// Check allowed tools
|
||||
var allowedTools []string
|
||||
if err := json.Unmarshal([]byte(childSession.AllowedTools), &allowedTools); err != nil {
|
||||
t.Fatalf("Failed to unmarshal AllowedTools: %v", err)
|
||||
}
|
||||
if len(allowedTools) != 3 || allowedTools[0] != "override1" {
|
||||
t.Errorf("AllowedTools override failed: got %v", allowedTools)
|
||||
}
|
||||
|
||||
// Check disallowed tools
|
||||
var disallowedTools []string
|
||||
if err := json.Unmarshal([]byte(childSession.DisallowedTools), &disallowedTools); err != nil {
|
||||
t.Fatalf("Failed to unmarshal DisallowedTools: %v", err)
|
||||
}
|
||||
if len(disallowedTools) != 2 || disallowedTools[0] != "override4" {
|
||||
t.Errorf("DisallowedTools override failed: got %v", disallowedTools)
|
||||
}
|
||||
|
||||
if childSession.MaxTurns != 5 {
|
||||
t.Errorf("MaxTurns override failed: got %d", childSession.MaxTurns)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MCPConfigOverride", func(t *testing.T) {
|
||||
// Create parent session with MCP
|
||||
parentSessionID := "parent-mcp-override"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-mcp-override",
|
||||
ClaudeSessionID: "claude-mcp-override",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original",
|
||||
Model: "claude-3-opus-20240229",
|
||||
WorkingDir: "/tmp/test",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := sqliteStore.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Store original MCP servers
|
||||
originalServers := []store.MCPServer{
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
Name: "original-server",
|
||||
Command: "original-cmd",
|
||||
ArgsJSON: `["original"]`,
|
||||
EnvJSON: `{"ORIGINAL": "true"}`,
|
||||
},
|
||||
}
|
||||
if err := sqliteStore.StoreMCPServers(ctx, parentSessionID, originalServers); err != nil {
|
||||
t.Fatalf("Failed to store MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Continue with MCP override
|
||||
overrideMCP := &claudecode.MCPConfig{
|
||||
MCPServers: map[string]claudecode.MCPServer{
|
||||
"override-server": {
|
||||
Command: "override-cmd",
|
||||
Args: []string{"override"},
|
||||
Env: map[string]string{"OVERRIDE": "true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: parentSessionID,
|
||||
Query: "override query",
|
||||
MCPConfig: overrideMCP,
|
||||
}
|
||||
|
||||
_, _ = manager.ContinueSession(ctx, req)
|
||||
// Expected to fail due to missing Claude binary
|
||||
|
||||
// Find the child session
|
||||
sessions, err := sqliteStore.ListSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list sessions: %v", err)
|
||||
}
|
||||
|
||||
var childSession *store.Session
|
||||
for _, s := range sessions {
|
||||
if s.ParentSessionID == parentSessionID {
|
||||
childSession = s
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if childSession == nil {
|
||||
t.Fatal("Child session not found")
|
||||
}
|
||||
|
||||
// Get MCP servers for child
|
||||
childMCPServers, err := sqliteStore.GetMCPServers(ctx, childSession.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get child MCP servers: %v", err)
|
||||
}
|
||||
|
||||
// Should have the override server, not the original
|
||||
if len(childMCPServers) != 1 {
|
||||
t.Fatalf("Expected 1 MCP server, got %d", len(childMCPServers))
|
||||
}
|
||||
|
||||
server := childMCPServers[0]
|
||||
if server.Name != "override-server" {
|
||||
t.Errorf("MCP server name not overridden: got %s", server.Name)
|
||||
}
|
||||
if server.Command != "override-cmd" {
|
||||
t.Errorf("MCP server command not overridden: got %s", server.Command)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -197,55 +197,84 @@ func (m *Manager) LaunchSession(ctx context.Context, config claudecode.SessionCo
|
||||
func (m *Manager) monitorSession(ctx context.Context, sessionID, runID string, claudeSession *claudecode.Session, startTime time.Time, config claudecode.SessionConfig) {
|
||||
// Get the session ID from the Claude session once available
|
||||
var claudeSessionID string
|
||||
for event := range claudeSession.Events {
|
||||
// Store raw event for debugging
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal event", "error", err)
|
||||
} else {
|
||||
if err := m.store.StoreRawEvent(ctx, sessionID, string(eventJSON)); err != nil {
|
||||
slog.Debug("failed to store raw event", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Capture Claude session ID
|
||||
if event.SessionID != "" && claudeSessionID == "" {
|
||||
claudeSessionID = event.SessionID
|
||||
// Note: Claude session ID captured for resume capability
|
||||
slog.Debug("captured Claude session ID",
|
||||
"session_id", sessionID,
|
||||
"claude_session_id", claudeSessionID)
|
||||
|
||||
// Update database
|
||||
update := store.SessionUpdate{
|
||||
ClaudeSessionID: &claudeSessionID,
|
||||
}
|
||||
if err := m.store.UpdateSession(ctx, sessionID, update); err != nil {
|
||||
slog.Error("failed to update session in database", "error", err)
|
||||
eventLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, stop processing
|
||||
slog.Debug("monitorSession context cancelled, stopping event processing",
|
||||
"session_id", sessionID)
|
||||
return
|
||||
case event, ok := <-claudeSession.Events:
|
||||
if !ok {
|
||||
// Channel closed, exit loop
|
||||
break eventLoop
|
||||
}
|
||||
|
||||
// Inject the pending query now that we have Claude session ID
|
||||
if queryVal, ok := m.pendingQueries.LoadAndDelete(sessionID); ok {
|
||||
if query, ok := queryVal.(string); ok && query != "" {
|
||||
if err := m.injectQueryAsFirstEvent(ctx, sessionID, claudeSessionID, query); err != nil {
|
||||
slog.Error("failed to inject query as first event",
|
||||
"sessionID", sessionID,
|
||||
"claudeSessionID", claudeSessionID,
|
||||
"error", err)
|
||||
// Check context before each database operation
|
||||
if ctx.Err() != nil {
|
||||
slog.Debug("context cancelled during event processing",
|
||||
"session_id", sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
// Store raw event for debugging
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal event", "error", err)
|
||||
} else {
|
||||
if err := m.store.StoreRawEvent(ctx, sessionID, string(eventJSON)); err != nil {
|
||||
slog.Debug("failed to store raw event", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Capture Claude session ID
|
||||
if event.SessionID != "" && claudeSessionID == "" {
|
||||
claudeSessionID = event.SessionID
|
||||
// Note: Claude session ID captured for resume capability
|
||||
slog.Debug("captured Claude session ID",
|
||||
"session_id", sessionID,
|
||||
"claude_session_id", claudeSessionID)
|
||||
|
||||
// Update database
|
||||
update := store.SessionUpdate{
|
||||
ClaudeSessionID: &claudeSessionID,
|
||||
}
|
||||
if err := m.store.UpdateSession(ctx, sessionID, update); err != nil {
|
||||
slog.Error("failed to update session in database", "error", err)
|
||||
}
|
||||
|
||||
// Inject the pending query now that we have Claude session ID
|
||||
if queryVal, ok := m.pendingQueries.LoadAndDelete(sessionID); ok {
|
||||
if query, ok := queryVal.(string); ok && query != "" {
|
||||
if err := m.injectQueryAsFirstEvent(ctx, sessionID, claudeSessionID, query); err != nil {
|
||||
slog.Error("failed to inject query as first event",
|
||||
"sessionID", sessionID,
|
||||
"claudeSessionID", claudeSessionID,
|
||||
"error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process and store event
|
||||
if err := m.processStreamEvent(ctx, sessionID, claudeSessionID, event); err != nil {
|
||||
slog.Error("failed to process stream event", "error", err)
|
||||
// Process and store event
|
||||
if err := m.processStreamEvent(ctx, sessionID, claudeSessionID, event); err != nil {
|
||||
slog.Error("failed to process stream event", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for session to complete
|
||||
result, err := claudeSession.Wait()
|
||||
|
||||
// Check if context was cancelled before updating database
|
||||
if ctx.Err() != nil {
|
||||
slog.Debug("context cancelled, skipping final session updates",
|
||||
"session_id", sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
if err != nil {
|
||||
m.updateSessionStatus(ctx, sessionID, StatusFailed, err.Error())
|
||||
@@ -724,17 +753,64 @@ func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig
|
||||
}
|
||||
|
||||
// Build config for resumed session
|
||||
// Start with minimal required fields
|
||||
// Start by inheriting ALL configuration from parent session
|
||||
config := claudecode.SessionConfig{
|
||||
Query: req.Query,
|
||||
SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag
|
||||
OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON
|
||||
// Inherit Model and WorkingDir from parent session for database storage
|
||||
Model: claudecode.Model(parentSession.Model),
|
||||
WorkingDir: parentSession.WorkingDir,
|
||||
Query: req.Query,
|
||||
SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag
|
||||
OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON
|
||||
Model: claudecode.Model(parentSession.Model),
|
||||
WorkingDir: parentSession.WorkingDir,
|
||||
SystemPrompt: parentSession.SystemPrompt,
|
||||
AppendSystemPrompt: parentSession.AppendSystemPrompt,
|
||||
CustomInstructions: parentSession.CustomInstructions,
|
||||
PermissionPromptTool: parentSession.PermissionPromptTool,
|
||||
// MaxTurns intentionally NOT inherited - let it default or be specified
|
||||
}
|
||||
|
||||
// Apply optional overrides
|
||||
// Deserialize JSON arrays for tools
|
||||
if parentSession.AllowedTools != "" {
|
||||
var allowedTools []string
|
||||
if err := json.Unmarshal([]byte(parentSession.AllowedTools), &allowedTools); err == nil {
|
||||
config.AllowedTools = allowedTools
|
||||
}
|
||||
}
|
||||
if parentSession.DisallowedTools != "" {
|
||||
var disallowedTools []string
|
||||
if err := json.Unmarshal([]byte(parentSession.DisallowedTools), &disallowedTools); err == nil {
|
||||
config.DisallowedTools = disallowedTools
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve and inherit MCP configuration from parent session
|
||||
mcpServers, err := m.store.GetMCPServers(ctx, req.ParentSessionID)
|
||||
if err == nil && len(mcpServers) > 0 {
|
||||
config.MCPConfig = &claudecode.MCPConfig{
|
||||
MCPServers: make(map[string]claudecode.MCPServer),
|
||||
}
|
||||
for _, server := range mcpServers {
|
||||
var args []string
|
||||
var env map[string]string
|
||||
if err := json.Unmarshal([]byte(server.ArgsJSON), &args); err != nil {
|
||||
slog.Warn("failed to unmarshal MCP server args", "error", err, "server", server.Name)
|
||||
args = []string{}
|
||||
}
|
||||
if err := json.Unmarshal([]byte(server.EnvJSON), &env); err != nil {
|
||||
slog.Warn("failed to unmarshal MCP server env", "error", err, "server", server.Name)
|
||||
env = map[string]string{}
|
||||
}
|
||||
|
||||
config.MCPConfig.MCPServers[server.Name] = claudecode.MCPServer{
|
||||
Command: server.Command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
}
|
||||
}
|
||||
slog.Debug("inherited MCP servers from parent session",
|
||||
"parent_session_id", req.ParentSessionID,
|
||||
"mcp_server_count", len(mcpServers))
|
||||
}
|
||||
|
||||
// Apply optional overrides (only if explicitly provided)
|
||||
if req.SystemPrompt != "" {
|
||||
config.SystemPrompt = req.SystemPrompt
|
||||
}
|
||||
|
||||
@@ -279,6 +279,9 @@ func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) {
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
// Expect GetMCPServers call (even if it returns empty)
|
||||
mockStore.EXPECT().GetMCPServers(gomock.Any(), "parent-1").Return([]store.MCPServer{}, nil)
|
||||
|
||||
// Expect session creation with parent reference
|
||||
mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx interface{}, session *store.Session) error {
|
||||
@@ -299,6 +302,9 @@ func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Expect MCP servers to be stored (may or may not be called)
|
||||
mockStore.EXPECT().StoreMCPServers(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
// Expect status update to running (we can't test the full flow without mocking Claude client)
|
||||
// May be called twice if Claude fails to launch in background
|
||||
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
@@ -348,6 +354,9 @@ func TestContinueSession_HandlesOptionalOverrides(t *testing.T) {
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
// Expect GetMCPServers call (even if it returns empty)
|
||||
mockStore.EXPECT().GetMCPServers(gomock.Any(), "parent-1").Return([]store.MCPServer{}, nil)
|
||||
|
||||
// Test with various overrides
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "parent-1",
|
||||
@@ -377,6 +386,9 @@ func TestContinueSession_HandlesOptionalOverrides(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Expect MCP servers to be stored (if MCPConfig override is provided)
|
||||
mockStore.EXPECT().StoreMCPServers(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
// Expect status update
|
||||
// May be called twice if Claude fails to launch in background
|
||||
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
@@ -53,6 +54,12 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
||||
return nil, fmt.Errorf("failed to initialize schema: %w", err)
|
||||
}
|
||||
|
||||
// Apply migrations (this must be called AFTER initSchema for both new and existing databases)
|
||||
if err := store.applyMigrations(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("failed to apply migrations: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("SQLite store initialized", "path", dbPath)
|
||||
return store, nil
|
||||
}
|
||||
@@ -74,7 +81,11 @@ func (s *SQLiteStore) initSchema() error {
|
||||
working_dir TEXT,
|
||||
max_turns INTEGER,
|
||||
system_prompt TEXT,
|
||||
append_system_prompt TEXT,
|
||||
custom_instructions TEXT,
|
||||
permission_prompt_tool TEXT,
|
||||
allowed_tools TEXT,
|
||||
disallowed_tools TEXT,
|
||||
|
||||
-- Runtime status
|
||||
status TEXT NOT NULL DEFAULT 'starting',
|
||||
@@ -168,14 +179,66 @@ func (s *SQLiteStore) initSchema() error {
|
||||
return fmt.Errorf("failed to create schema: %w", err)
|
||||
}
|
||||
|
||||
// Record schema version
|
||||
// Record initial schema version
|
||||
// For new databases, we start at version 3 since the schema includes all fields
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO schema_version (version, description)
|
||||
VALUES (1, 'Initial schema with conversation events')
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark new databases as having all migrations applied
|
||||
_, err = s.db.Exec(`
|
||||
INSERT OR IGNORE INTO schema_version (version, description)
|
||||
VALUES (3, 'Initial schema includes all permission and tool fields')
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// applyMigrations applies any pending database migrations
|
||||
func (s *SQLiteStore) applyMigrations() error {
|
||||
// Get current schema version
|
||||
var currentVersion int
|
||||
err := s.db.QueryRow("SELECT MAX(version) FROM schema_version").Scan(¤tVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current schema version: %w", err)
|
||||
}
|
||||
|
||||
// Migration 2: Added constraint to ensure only resumable sessions can be parent sessions
|
||||
// (This migration already exists in production databases)
|
||||
|
||||
// Migration 3: Add missing permission and tool fields
|
||||
if currentVersion < 3 {
|
||||
slog.Info("Applying migration 3: Add permission and tool fields")
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
-- Add missing columns to sessions table
|
||||
ALTER TABLE sessions ADD COLUMN permission_prompt_tool TEXT;
|
||||
ALTER TABLE sessions ADD COLUMN append_system_prompt TEXT;
|
||||
ALTER TABLE sessions ADD COLUMN allowed_tools TEXT;
|
||||
ALTER TABLE sessions ADD COLUMN disallowed_tools TEXT;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply migration 3: %w", err)
|
||||
}
|
||||
|
||||
// Record migration
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO schema_version (version, description)
|
||||
VALUES (3, 'Add permission_prompt_tool, append_system_prompt, allowed_tools, disallowed_tools fields')
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record migration 3: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("Migration 3 applied successfully")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (s *SQLiteStore) Close() error {
|
||||
return s.db.Close()
|
||||
@@ -186,15 +249,17 @@ func (s *SQLiteStore) CreateSession(ctx context.Context, session *Session) error
|
||||
query := `
|
||||
INSERT INTO sessions (
|
||||
id, run_id, claude_session_id, parent_session_id,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
|
||||
permission_prompt_tool, allowed_tools, disallowed_tools,
|
||||
status, created_at, last_activity_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query,
|
||||
session.ID, session.RunID, session.ClaudeSessionID, session.ParentSessionID,
|
||||
session.Query, session.Summary, session.Model, session.WorkingDir, session.MaxTurns,
|
||||
session.SystemPrompt, session.CustomInstructions,
|
||||
session.SystemPrompt, session.AppendSystemPrompt, session.CustomInstructions,
|
||||
session.PermissionPromptTool, session.AllowedTools, session.DisallowedTools,
|
||||
session.Status, session.CreatedAt, session.LastActivityAt,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -284,14 +349,16 @@ func (s *SQLiteStore) UpdateSession(ctx context.Context, sessionID string, updat
|
||||
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
query := `
|
||||
SELECT id, run_id, claude_session_id, parent_session_id,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
|
||||
permission_prompt_tool, allowed_tools, disallowed_tools,
|
||||
status, created_at, last_activity_at, completed_at,
|
||||
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
|
||||
FROM sessions WHERE id = ?
|
||||
`
|
||||
|
||||
var session Session
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
|
||||
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
|
||||
var completedAt sql.NullTime
|
||||
var costUSD sql.NullFloat64
|
||||
var totalTokens, durationMS, numTurns sql.NullInt64
|
||||
@@ -300,7 +367,8 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
err := s.db.QueryRowContext(ctx, query, sessionID).Scan(
|
||||
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
|
||||
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
|
||||
&systemPrompt, &customInstructions,
|
||||
&systemPrompt, &appendSystemPrompt, &customInstructions,
|
||||
&permissionPromptTool, &allowedTools, &disallowedTools,
|
||||
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
|
||||
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
|
||||
)
|
||||
@@ -318,7 +386,11 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
session.Model = model.String
|
||||
session.WorkingDir = workingDir.String
|
||||
session.SystemPrompt = systemPrompt.String
|
||||
session.AppendSystemPrompt = appendSystemPrompt.String
|
||||
session.CustomInstructions = customInstructions.String
|
||||
session.PermissionPromptTool = permissionPromptTool.String
|
||||
session.AllowedTools = allowedTools.String
|
||||
session.DisallowedTools = disallowedTools.String
|
||||
session.ResultContent = resultContent.String
|
||||
session.ErrorMessage = errorMessage.String
|
||||
if completedAt.Valid {
|
||||
@@ -347,7 +419,8 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Session, error) {
|
||||
query := `
|
||||
SELECT id, run_id, claude_session_id, parent_session_id,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
|
||||
permission_prompt_tool, allowed_tools, disallowed_tools,
|
||||
status, created_at, last_activity_at, completed_at,
|
||||
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
|
||||
FROM sessions
|
||||
@@ -355,7 +428,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
|
||||
`
|
||||
|
||||
var session Session
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
|
||||
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
|
||||
var completedAt sql.NullTime
|
||||
var costUSD sql.NullFloat64
|
||||
var totalTokens, durationMS, numTurns sql.NullInt64
|
||||
@@ -364,7 +438,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
|
||||
err := s.db.QueryRowContext(ctx, query, runID).Scan(
|
||||
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
|
||||
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
|
||||
&systemPrompt, &customInstructions,
|
||||
&systemPrompt, &appendSystemPrompt, &customInstructions,
|
||||
&permissionPromptTool, &allowedTools, &disallowedTools,
|
||||
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
|
||||
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
|
||||
)
|
||||
@@ -382,7 +457,11 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
|
||||
session.Model = model.String
|
||||
session.WorkingDir = workingDir.String
|
||||
session.SystemPrompt = systemPrompt.String
|
||||
session.AppendSystemPrompt = appendSystemPrompt.String
|
||||
session.CustomInstructions = customInstructions.String
|
||||
session.PermissionPromptTool = permissionPromptTool.String
|
||||
session.AllowedTools = allowedTools.String
|
||||
session.DisallowedTools = disallowedTools.String
|
||||
session.ResultContent = resultContent.String
|
||||
session.ErrorMessage = errorMessage.String
|
||||
if completedAt.Valid {
|
||||
@@ -411,7 +490,8 @@ func (s *SQLiteStore) GetSessionByRunID(ctx context.Context, runID string) (*Ses
|
||||
func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
|
||||
query := `
|
||||
SELECT id, run_id, claude_session_id, parent_session_id,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, custom_instructions,
|
||||
query, summary, model, working_dir, max_turns, system_prompt, append_system_prompt, custom_instructions,
|
||||
permission_prompt_tool, allowed_tools, disallowed_tools,
|
||||
status, created_at, last_activity_at, completed_at,
|
||||
cost_usd, total_tokens, duration_ms, num_turns, result_content, error_message
|
||||
FROM sessions
|
||||
@@ -427,7 +507,8 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
|
||||
var sessions []*Session
|
||||
for rows.Next() {
|
||||
var session Session
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, customInstructions sql.NullString
|
||||
var claudeSessionID, parentSessionID, summary, model, workingDir, systemPrompt, appendSystemPrompt, customInstructions sql.NullString
|
||||
var permissionPromptTool, allowedTools, disallowedTools sql.NullString
|
||||
var completedAt sql.NullTime
|
||||
var costUSD sql.NullFloat64
|
||||
var totalTokens, durationMS, numTurns sql.NullInt64
|
||||
@@ -436,7 +517,8 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
|
||||
err := rows.Scan(
|
||||
&session.ID, &session.RunID, &claudeSessionID, &parentSessionID,
|
||||
&session.Query, &summary, &model, &workingDir, &session.MaxTurns,
|
||||
&systemPrompt, &customInstructions,
|
||||
&systemPrompt, &appendSystemPrompt, &customInstructions,
|
||||
&permissionPromptTool, &allowedTools, &disallowedTools,
|
||||
&session.Status, &session.CreatedAt, &session.LastActivityAt, &completedAt,
|
||||
&costUSD, &totalTokens, &durationMS, &numTurns, &resultContent, &errorMessage,
|
||||
)
|
||||
@@ -451,7 +533,11 @@ func (s *SQLiteStore) ListSessions(ctx context.Context) ([]*Session, error) {
|
||||
session.Model = model.String
|
||||
session.WorkingDir = workingDir.String
|
||||
session.SystemPrompt = systemPrompt.String
|
||||
session.AppendSystemPrompt = appendSystemPrompt.String
|
||||
session.CustomInstructions = customInstructions.String
|
||||
session.PermissionPromptTool = permissionPromptTool.String
|
||||
session.AllowedTools = allowedTools.String
|
||||
session.DisallowedTools = disallowedTools.String
|
||||
session.ResultContent = resultContent.String
|
||||
session.ErrorMessage = errorMessage.String
|
||||
if completedAt.Valid {
|
||||
@@ -573,6 +659,7 @@ func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID stri
|
||||
// Walk up the parent chain to get all related claude session IDs
|
||||
claudeSessionIDs := []string{}
|
||||
currentID := sessionID
|
||||
isFirstSession := true
|
||||
|
||||
for currentID != "" {
|
||||
var claudeSessionID sql.NullString
|
||||
@@ -584,10 +671,16 @@ func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID stri
|
||||
).Scan(&claudeSessionID, &parentID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
break // Session not found, stop walking
|
||||
// If the requested session doesn't exist, return error
|
||||
if isFirstSession {
|
||||
return nil, fmt.Errorf("session not found: %s", sessionID)
|
||||
}
|
||||
// Otherwise, parent not found, just stop walking
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
isFirstSession = false
|
||||
|
||||
// Add claude session ID if present (in reverse order for chronological events)
|
||||
if claudeSessionID.Valid && claudeSessionID.String != "" {
|
||||
@@ -892,6 +985,21 @@ func (s *SQLiteStore) CorrelateApprovalByToolID(ctx context.Context, sessionID s
|
||||
|
||||
// UpdateApprovalStatus updates the status of an approval
|
||||
func (s *SQLiteStore) UpdateApprovalStatus(ctx context.Context, approvalID string, status string) error {
|
||||
// Special handling for resolved status - don't overwrite approved/denied
|
||||
if status == ApprovalStatusResolved {
|
||||
query := `
|
||||
UPDATE conversation_events
|
||||
SET approval_status = ?
|
||||
WHERE approval_id = ? AND approval_status = ?
|
||||
`
|
||||
_, err := s.db.ExecContext(ctx, query, status, approvalID, ApprovalStatusPending)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update approval status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For approved/denied, always update
|
||||
query := `
|
||||
UPDATE conversation_events
|
||||
SET approval_status = ?
|
||||
@@ -935,6 +1043,7 @@ func (s *SQLiteStore) GetMCPServers(ctx context.Context, sessionID string) ([]MC
|
||||
SELECT id, session_id, name, command, args_json, env_json
|
||||
FROM mcp_servers
|
||||
WHERE session_id = ?
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, sessionID)
|
||||
@@ -975,8 +1084,17 @@ func (s *SQLiteStore) StoreRawEvent(ctx context.Context, sessionID string, event
|
||||
|
||||
// Helper function to convert MCP config to store format
|
||||
func MCPServersFromConfig(sessionID string, config map[string]claudecode.MCPServer) ([]MCPServer, error) {
|
||||
// First, collect all server names and sort them for deterministic ordering
|
||||
names := make([]string, 0, len(config))
|
||||
for name := range config {
|
||||
names = append(names, name)
|
||||
}
|
||||
// Sort names to ensure consistent ordering
|
||||
sort.Strings(names)
|
||||
|
||||
servers := make([]MCPServer, 0, len(config))
|
||||
for name, server := range config {
|
||||
for _, name := range names {
|
||||
server := config[name]
|
||||
argsJSON, err := json.Marshal(server.Args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal args: %w", err)
|
||||
|
||||
@@ -436,10 +436,10 @@ func TestGetSessionConversationWithParentChain(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_NonExistentSession", func(t *testing.T) {
|
||||
// Should return empty events for non-existent session
|
||||
events, err := store.GetSessionConversation(ctx, "does-not-exist")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 0)
|
||||
// Should return error for non-existent session
|
||||
_, err := store.GetSessionConversation(ctx, "does-not-exist")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "session not found: does-not-exist")
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_NoParent", func(t *testing.T) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
@@ -42,27 +43,31 @@ type ConversationStore interface {
|
||||
|
||||
// Session represents a Claude Code session
|
||||
type Session struct {
|
||||
ID string
|
||||
RunID string
|
||||
ClaudeSessionID string
|
||||
ParentSessionID string
|
||||
Query string
|
||||
Summary string
|
||||
Model string
|
||||
WorkingDir string
|
||||
MaxTurns int
|
||||
SystemPrompt string
|
||||
CustomInstructions string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
LastActivityAt time.Time
|
||||
CompletedAt *time.Time
|
||||
CostUSD *float64
|
||||
TotalTokens *int
|
||||
DurationMS *int
|
||||
NumTurns *int
|
||||
ResultContent string
|
||||
ErrorMessage string
|
||||
ID string
|
||||
RunID string
|
||||
ClaudeSessionID string
|
||||
ParentSessionID string
|
||||
Query string
|
||||
Summary string
|
||||
Model string
|
||||
WorkingDir string
|
||||
MaxTurns int
|
||||
SystemPrompt string
|
||||
AppendSystemPrompt string // NEW: Append to system prompt
|
||||
CustomInstructions string
|
||||
PermissionPromptTool string // NEW: MCP tool for permission prompts
|
||||
AllowedTools string // NEW: JSON array of allowed tools
|
||||
DisallowedTools string // NEW: JSON array of disallowed tools
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
LastActivityAt time.Time
|
||||
CompletedAt *time.Time
|
||||
CostUSD *float64
|
||||
TotalTokens *int
|
||||
DurationMS *int
|
||||
NumTurns *int
|
||||
ResultContent string
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
// SessionUpdate contains fields that can be updated
|
||||
@@ -148,17 +153,25 @@ const (
|
||||
|
||||
// NewSessionFromConfig creates a Session from Claude SessionConfig
|
||||
func NewSessionFromConfig(id, runID string, config claudecode.SessionConfig) *Session {
|
||||
// Convert slices to JSON for storage
|
||||
allowedToolsJSON, _ := json.Marshal(config.AllowedTools)
|
||||
disallowedToolsJSON, _ := json.Marshal(config.DisallowedTools)
|
||||
|
||||
return &Session{
|
||||
ID: id,
|
||||
RunID: runID,
|
||||
Query: config.Query,
|
||||
Model: string(config.Model),
|
||||
WorkingDir: config.WorkingDir,
|
||||
MaxTurns: config.MaxTurns,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
CustomInstructions: config.CustomInstructions,
|
||||
Status: SessionStatusStarting,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
ID: id,
|
||||
RunID: runID,
|
||||
Query: config.Query,
|
||||
Model: string(config.Model),
|
||||
WorkingDir: config.WorkingDir,
|
||||
MaxTurns: config.MaxTurns,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
AppendSystemPrompt: config.AppendSystemPrompt,
|
||||
CustomInstructions: config.CustomInstructions,
|
||||
PermissionPromptTool: config.PermissionPromptTool,
|
||||
AllowedTools: string(allowedToolsJSON),
|
||||
DisallowedTools: string(disallowedToolsJSON),
|
||||
Status: SessionStatusStarting,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,10 +32,6 @@ type conversationModel struct {
|
||||
resumeInput textinput.Model
|
||||
showResumePrompt bool
|
||||
|
||||
// Parent session data for inheritance (stored during resume)
|
||||
parentModel string
|
||||
parentWorkingDir string
|
||||
|
||||
// Loading states
|
||||
loading bool
|
||||
error error
|
||||
@@ -85,9 +81,6 @@ func (cm *conversationModel) setSession(sessionID string) {
|
||||
cm.clearApprovalState()
|
||||
cm.clearResumeState()
|
||||
cm.stopPolling() // Stop any existing polling
|
||||
// Clear parent data when switching sessions
|
||||
cm.parentModel = ""
|
||||
cm.parentWorkingDir = ""
|
||||
// Reset scroll tracking
|
||||
cm.wasAtBottom = true
|
||||
}
|
||||
@@ -217,16 +210,6 @@ func (cm *conversationModel) Update(msg tea.Msg, m *model) tea.Cmd {
|
||||
cm.events = msg.events
|
||||
cm.lastRefresh = time.Now()
|
||||
|
||||
// If this is a child session with missing data, use parent data stored during resume
|
||||
if cm.session != nil && cm.session.ParentSessionID != "" {
|
||||
if cm.session.Model == "" && cm.parentModel != "" {
|
||||
cm.session.Model = cm.parentModel
|
||||
}
|
||||
if cm.session.WorkingDir == "" && cm.parentWorkingDir != "" {
|
||||
cm.session.WorkingDir = cm.parentWorkingDir
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the conversation for future use (if session and events are not nil)
|
||||
if cm.session != nil && cm.events != nil {
|
||||
m.conversationCache.put(cm.sessionID, cm.session, cm.events)
|
||||
@@ -390,9 +373,6 @@ func (cm *conversationModel) updateResumeInput(msg tea.KeyMsg, m *model) tea.Cmd
|
||||
if cm.session != nil && cm.sessionID != "" {
|
||||
query := cm.resumeInput.Value()
|
||||
if query != "" {
|
||||
// Store parent session data for inheritance
|
||||
cm.parentModel = cm.session.Model
|
||||
cm.parentWorkingDir = cm.session.WorkingDir
|
||||
return continueSession(m.daemonClient, cm.sessionID, query)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user