From b8ea3ee9a70901d0c509a04666743d1ed63198ad Mon Sep 17 00:00:00 2001 From: Allison Durham Date: Mon, 9 Jun 2025 14:49:09 -0700 Subject: [PATCH] Add multi-turn conversation support with session continuation (#207) * resuming session work * pull request template formatting * make claudecode-go happy with race conditions * tests for session continuation --- .github/PULL_REQUEST_TEMPLATE.md | 43 +- claudecode-go/client.go | 31 +- claudecode-go/types.go | 22 +- hld/client/client.go | 9 + hld/client/types.go | 3 + ...aemon_continue_session_integration_test.go | 515 ++++++++++++++++++ hld/rpc/handlers.go | 55 +- hld/rpc/handlers_continue_session_test.go | 248 +++++++++ hld/rpc/types.go | 23 + hld/session/manager.go | 140 +++++ hld/session/manager_test.go | 278 ++++++++++ hld/session/types.go | 17 + hld/store/sqlite.go | 99 +++- hld/store/sqlite_test.go | 200 +++++++ 14 files changed, 1619 insertions(+), 64 deletions(-) create mode 100644 hld/daemon/daemon_continue_session_integration_test.go create mode 100644 hld/rpc/handlers_continue_session_test.go diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0e4ce8e..dfc23b4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,44 +1,13 @@ - - -**What I did** - - - -**How I did it** - - +## How to verify it - [ ] I have ensured `make check test` passes -**How to verify it** +## Description for the changelog - - -**Description for the changelog** - - - - +## A picture of a cute animal (not mandatory but encouraged) diff --git a/claudecode-go/client.go b/claudecode-go/client.go index aaa3a9d..a0d2036 100644 --- a/claudecode-go/client.go +++ b/claudecode-go/client.go @@ -185,7 +185,7 @@ func (c *Client) Launch(config SessionConfig) (*Session, error) { // Wait for process to complete in background go func() { // Wait for the command to exit - session.err = cmd.Wait() + session.SetError(cmd.Wait()) // IMPORTANT: Wait for parsing to complete before signaling done. // This ensures that all output has been read and processed before @@ -213,8 +213,8 @@ func (c *Client) LaunchAndWait(config SessionConfig) (*Result, error) { func (s *Session) Wait() (*Result, error) { <-s.done - if s.err != nil && s.result == nil { - return nil, fmt.Errorf("claude process failed: %w", s.err) + if err := s.Error(); err != nil && s.result == nil { + return nil, fmt.Errorf("claude process failed: %w", err) } return s.result, nil @@ -232,9 +232,11 @@ func (s *Session) Kill() error { func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) { scanner := bufio.NewScanner(stdout) var stderrBuf strings.Builder + stderrDone := make(chan struct{}) // Capture stderr in background go func() { + defer close(stderrDone) buf := make([]byte, 1024) for { n, err := stderr.Read(buf) @@ -283,9 +285,12 @@ func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) { s.Events <- event } + // Wait for stderr reading to complete before accessing the buffer + <-stderrDone + // If we got stderr output, that's an error if stderrOutput := stderrBuf.String(); stderrOutput != "" { - s.err = fmt.Errorf("claude error: %s", stderrOutput) + s.SetError(fmt.Errorf("claude error: %s", stderrOutput)) } // Close events channel when done parsing @@ -296,7 +301,7 @@ func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) { func (s *Session) parseSingleJSON(stdout, stderr io.Reader) { defer func() { if r := recover(); r != nil { - s.err = fmt.Errorf("panic in parseSingleJSON: %v", r) + s.SetError(fmt.Errorf("panic in parseSingleJSON: %v", r)) } }() @@ -304,26 +309,26 @@ func (s *Session) parseSingleJSON(stdout, stderr io.Reader) { // Read all stdout if _, err := io.Copy(&stdoutBuf, stdout); err != nil { - s.err = fmt.Errorf("failed to read stdout: %w", err) + s.SetError(fmt.Errorf("failed to read stdout: %w", err)) return } // Read all stderr if _, err := io.Copy(&stderrBuf, stderr); err != nil { - s.err = fmt.Errorf("failed to read stderr: %w", err) + s.SetError(fmt.Errorf("failed to read stderr: %w", err)) return } // Parse JSON result output := stdoutBuf.String() if output == "" { - s.err = fmt.Errorf("no output from claude") + s.SetError(fmt.Errorf("no output from claude")) return } var result Result if err := json.Unmarshal([]byte(output), &result); err != nil { - s.err = fmt.Errorf("failed to parse JSON output: %w\nOutput was: %s", err, output) + s.SetError(fmt.Errorf("failed to parse JSON output: %w\nOutput was: %s", err, output)) return } s.result = &result @@ -333,7 +338,7 @@ func (s *Session) parseSingleJSON(stdout, stderr io.Reader) { if stderrOutput := stderrBuf.String(); stderrOutput != "" { // Don't override result if we got valid JSON if s.result == nil { - s.err = fmt.Errorf("claude error: %s", stderrOutput) + s.SetError(fmt.Errorf("claude error: %s", stderrOutput)) } } } @@ -344,13 +349,13 @@ func (s *Session) parseTextOutput(stdout, stderr io.Reader) { // Read all stdout if _, err := io.Copy(&stdoutBuf, stdout); err != nil { - s.err = fmt.Errorf("failed to read stdout: %w", err) + s.SetError(fmt.Errorf("failed to read stdout: %w", err)) return } // Read all stderr if _, err := io.Copy(&stderrBuf, stderr); err != nil { - s.err = fmt.Errorf("failed to read stderr: %w", err) + s.SetError(fmt.Errorf("failed to read stderr: %w", err)) return } @@ -365,6 +370,6 @@ func (s *Session) parseTextOutput(stdout, stderr io.Reader) { // If we got stderr output, that's an error if stderrOutput := stderrBuf.String(); stderrOutput != "" { - s.err = fmt.Errorf("claude error: %s", stderrOutput) + s.SetError(fmt.Errorf("claude error: %s", stderrOutput)) } } diff --git a/claudecode-go/types.go b/claudecode-go/types.go index ed3de10..9a24bed 100644 --- a/claudecode-go/types.go +++ b/claudecode-go/types.go @@ -2,6 +2,7 @@ package claudecode import ( "os/exec" + "sync" "time" ) @@ -140,5 +141,24 @@ type Session struct { cmd *exec.Cmd done chan struct{} result *Result - err error + + // Thread-safe error handling + mu sync.RWMutex + err error +} + +// SetError safely sets the error +func (s *Session) SetError(err error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.err == nil { + s.err = err + } +} + +// Error safely gets the error +func (s *Session) Error() error { + s.mu.RLock() + defer s.mu.RUnlock() + return s.err } diff --git a/hld/client/client.go b/hld/client/client.go index 1b50441..ac991bb 100644 --- a/hld/client/client.go +++ b/hld/client/client.go @@ -256,6 +256,15 @@ func (c *client) ListSessions() (*rpc.ListSessionsResponse, error) { return &resp, nil } +// ContinueSession continues an existing completed session with a new query +func (c *client) ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error) { + var resp rpc.ContinueSessionResponse + if err := c.call("continueSession", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + // FetchApprovals fetches pending approvals from the daemon func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, error) { req := rpc.FetchApprovalsRequest{ diff --git a/hld/client/types.go b/hld/client/types.go index 0d12864..8d44153 100644 --- a/hld/client/types.go +++ b/hld/client/types.go @@ -18,6 +18,9 @@ type Client interface { // ListSessions lists all active sessions ListSessions() (*rpc.ListSessionsResponse, error) + // ContinueSession continues an existing completed session with a new query + ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error) + // FetchApprovals fetches pending approvals from the daemon FetchApprovals(sessionID string) ([]approval.PendingApproval, error) diff --git a/hld/daemon/daemon_continue_session_integration_test.go b/hld/daemon/daemon_continue_session_integration_test.go new file mode 100644 index 0000000..beaca07 --- /dev/null +++ b/hld/daemon/daemon_continue_session_integration_test.go @@ -0,0 +1,515 @@ +package daemon + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net" + "testing" + "time" + + "github.com/humanlayer/humanlayer/hld/bus" + "github.com/humanlayer/humanlayer/hld/config" + "github.com/humanlayer/humanlayer/hld/internal/testutil" + "github.com/humanlayer/humanlayer/hld/rpc" + "github.com/humanlayer/humanlayer/hld/session" + "github.com/humanlayer/humanlayer/hld/store" +) + +func TestIntegrationContinueSession(t *testing.T) { + // Use test-specific socket path + socketPath := testutil.SocketPath(t, "continue-session") + + // Create daemon components + eventBus := bus.NewEventBus() + sqliteStore, err := store.NewSQLiteStore(":memory:") + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + defer func() { _ = sqliteStore.Close() }() + + sessionManager, err := session.NewManager(eventBus, sqliteStore) + if err != nil { + t.Fatalf("Failed to create session manager: %v", err) + } + + // Create daemon + d := &Daemon{ + socketPath: socketPath, + config: &config.Config{SocketPath: socketPath, DatabasePath: ":memory:"}, + eventBus: eventBus, + store: sqliteStore, + sessions: sessionManager, + rpcServer: rpc.NewServer(), + } + + // Register RPC handlers + sessionHandlers := rpc.NewSessionHandlers(sessionManager, sqliteStore) + sessionHandlers.Register(d.rpcServer) + + // Start daemon + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + if err := d.Run(ctx); err != nil { + t.Logf("daemon run error: %v", err) + } + }() + + // Wait for daemon to be ready + time.Sleep(200 * time.Millisecond) + + // Create helper function to send RPC requests + sendRPC := func(t *testing.T, method string, params interface{}) (json.RawMessage, error) { + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("failed to connect to daemon: %v", err) + } + defer func() { _ = conn.Close() }() + + request := map[string]interface{}{ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": 1, + } + + data, err := json.Marshal(request) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + + if _, err := conn.Write(append(data, '\n')); err != nil { + t.Fatalf("failed to write request: %v", err) + } + + scanner := bufio.NewScanner(conn) + if !scanner.Scan() { + t.Fatal("no response received") + } + + var response map[string]interface{} + if err := json.Unmarshal(scanner.Bytes(), &response); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if errObj, ok := response["error"]; ok { + if errMap, ok := errObj.(map[string]interface{}); ok { + if msg, ok := errMap["message"].(string); ok { + return nil, fmt.Errorf("%s", msg) + } + } + return nil, fmt.Errorf("RPC error: %v", errObj) + } + + if result, ok := response["result"]; ok { + resultBytes, err := json.Marshal(result) + if err != nil { + t.Fatalf("failed to marshal result: %v", err) + } + return resultBytes, nil + } + + return nil, fmt.Errorf("no result in response") + } + + t.Run("ContinueSession_RequiresCompletedParent", func(t *testing.T) { + // Create a parent session that's still running + parentSessionID := "parent-running" + parentSession := &store.Session{ + ID: parentSessionID, + RunID: "run-parent", + ClaudeSessionID: "claude-parent", + Status: store.SessionStatusRunning, // Not completed + Query: "original query", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + } + + // Insert parent session directly into database + if err := d.store.CreateSession(ctx, parentSession); err != nil { + t.Fatalf("Failed to create parent session: %v", err) + } + + // Try to continue the running session + req := rpc.ContinueSessionRequest{ + SessionID: parentSessionID, + Query: "continue this", + } + + _, err := sendRPC(t, "continueSession", req) + if err == nil { + t.Error("Expected error when continuing running session") + } + if err.Error() != "cannot continue session with status running (must be completed)" { + t.Errorf("Unexpected error: %v", err) + } + }) + + t.Run("ContinueSession_RequiresClaudeSessionID", func(t *testing.T) { + // Create a parent session without claude_session_id + parentSessionID := "parent-no-claude" + parentSession := &store.Session{ + ID: parentSessionID, + RunID: "run-no-claude", + ClaudeSessionID: "", // Missing + Status: store.SessionStatusCompleted, + Query: "original query", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + + // Insert parent session + if err := d.store.CreateSession(ctx, parentSession); err != nil { + t.Fatalf("Failed to create parent session: %v", err) + } + + // Try to continue without claude_session_id + req := rpc.ContinueSessionRequest{ + SessionID: parentSessionID, + Query: "continue this", + } + + _, err := sendRPC(t, "continueSession", req) + if err == nil { + t.Error("Expected error when continuing session without claude_session_id") + } + if err.Error() != "parent session missing claude_session_id (cannot resume)" { + t.Errorf("Unexpected error: %v", err) + } + }) + + t.Run("ContinueSession_CreatesChildSession", func(t *testing.T) { + // Create a valid completed parent session + parentSessionID := "parent-valid" + claudeSessionID := "claude-valid" + parentSession := &store.Session{ + ID: parentSessionID, + RunID: "run-valid", + ClaudeSessionID: claudeSessionID, + Status: store.SessionStatusCompleted, + Query: "original query", + Model: "claude-3-opus", + WorkingDir: "/test/dir", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + + // Insert parent session + if err := d.store.CreateSession(ctx, parentSession); err != nil { + t.Fatalf("Failed to create parent session: %v", err) + } + + // Add some conversation history to parent + events := []*store.ConversationEvent{ + { + SessionID: parentSessionID, + ClaudeSessionID: claudeSessionID, + EventType: store.EventTypeMessage, + Role: "user", + Content: "original query", + }, + { + SessionID: parentSessionID, + ClaudeSessionID: claudeSessionID, + EventType: store.EventTypeMessage, + Role: "assistant", + Content: "Original response", + }, + } + + for _, event := range events { + if err := d.store.AddConversationEvent(ctx, event); err != nil { + t.Fatalf("Failed to add conversation event: %v", err) + } + } + + // Continue the session + req := rpc.ContinueSessionRequest{ + SessionID: parentSessionID, + Query: "follow up question", + SystemPrompt: "You are helpful", + CustomInstructions: "Be concise", + MaxTurns: 3, + } + + result, err := sendRPC(t, "continueSession", req) + if err != nil { + // Expected - Claude binary might not exist in test environment + if err.Error() != "failed to continue session: failed to launch resumed Claude session: failed to start claude: exec: \"claude\": executable file not found in $PATH" { + t.Errorf("Unexpected error: %v", err) + } + // Even if Claude fails to launch, we should have created the session + return + } + + // Parse response + var resp rpc.ContinueSessionResponse + if err := json.Unmarshal(result, &resp); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Verify response + if resp.SessionID == "" { + t.Error("Expected session ID in response") + } + if resp.RunID == "" { + t.Error("Expected run ID in response") + } + if resp.ParentSessionID != parentSessionID { + t.Errorf("Expected parent_session_id %s, got %s", parentSessionID, resp.ParentSessionID) + } + + // Verify the new session was created with parent reference + newSession, err := d.store.GetSession(ctx, resp.SessionID) + if err != nil { + t.Fatalf("Failed to get new session: %v", err) + } + + if newSession.ParentSessionID != parentSessionID { + t.Errorf("Expected parent_session_id %s, got %s", parentSessionID, newSession.ParentSessionID) + } + if newSession.Query != "follow up question" { + t.Errorf("Expected query 'follow up question', got %s", newSession.Query) + } + if newSession.SystemPrompt != "You are helpful" { + t.Errorf("Expected system prompt override, got %s", newSession.SystemPrompt) + } + if newSession.CustomInstructions != "Be concise" { + t.Errorf("Expected custom instructions override, got %s", newSession.CustomInstructions) + } + if newSession.MaxTurns != 3 { + t.Errorf("Expected max turns 3, got %d", newSession.MaxTurns) + } + }) + + t.Run("ContinueSession_HandlesOptionalMCPConfig", func(t *testing.T) { + // Create parent session + parentSessionID := "parent-mcp" + parentSession := &store.Session{ + ID: parentSessionID, + RunID: "run-mcp", + ClaudeSessionID: "claude-mcp", + Status: store.SessionStatusCompleted, + Query: "original", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + + if err := d.store.CreateSession(ctx, parentSession); err != nil { + t.Fatalf("Failed to create parent session: %v", err) + } + + // Create MCP config + mcpConfig := map[string]interface{}{ + "mcpServers": map[string]interface{}{ + "test-server": map[string]interface{}{ + "command": "node", + "args": []string{"server.js"}, + "env": map[string]string{ + "TEST": "value", + }, + }, + }, + } + + mcpConfigJSON, err := json.Marshal(mcpConfig) + if err != nil { + t.Fatalf("Failed to marshal MCP config: %v", err) + } + + // Continue with MCP config + req := rpc.ContinueSessionRequest{ + SessionID: parentSessionID, + Query: "with mcp", + MCPConfig: string(mcpConfigJSON), + } + + _, err = sendRPC(t, "continueSession", req) + // Expected to fail (no Claude binary), but session should be created + if err != nil && !containsError(err, "failed to launch resumed Claude session") { + t.Errorf("Unexpected error: %v", err) + } + }) + + t.Run("GetConversation_IncludesParentHistory", func(t *testing.T) { + // Create a chain of sessions: grandparent -> parent -> child + grandparentID := "grandparent" + parentID := "parent-chain" + childID := "child" + + // Create grandparent session + grandparent := &store.Session{ + ID: grandparentID, + RunID: "run-gp", + ClaudeSessionID: "claude-gp", + Status: store.SessionStatusCompleted, + Query: "grandparent query", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + if err := d.store.CreateSession(ctx, grandparent); err != nil { + t.Fatalf("Failed to create grandparent: %v", err) + } + + // Add grandparent events + gpEvents := []*store.ConversationEvent{ + { + SessionID: grandparentID, + ClaudeSessionID: "claude-gp", + EventType: store.EventTypeMessage, + Role: "user", + Content: "grandparent query", + }, + { + SessionID: grandparentID, + ClaudeSessionID: "claude-gp", + EventType: store.EventTypeMessage, + Role: "assistant", + Content: "grandparent response", + }, + } + for _, event := range gpEvents { + if err := d.store.AddConversationEvent(ctx, event); err != nil { + t.Fatalf("Failed to add grandparent event: %v", err) + } + } + + // Create parent session + parent := &store.Session{ + ID: parentID, + RunID: "run-p", + ClaudeSessionID: "claude-p", + ParentSessionID: grandparentID, + Status: store.SessionStatusCompleted, + Query: "parent query", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + if err := d.store.CreateSession(ctx, parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Add parent events + pEvents := []*store.ConversationEvent{ + { + SessionID: parentID, + ClaudeSessionID: "claude-p", + EventType: store.EventTypeMessage, + Role: "user", + Content: "parent query", + }, + { + SessionID: parentID, + ClaudeSessionID: "claude-p", + EventType: store.EventTypeMessage, + Role: "assistant", + Content: "parent response", + }, + } + for _, event := range pEvents { + if err := d.store.AddConversationEvent(ctx, event); err != nil { + t.Fatalf("Failed to add parent event: %v", err) + } + } + + // Create child session + child := &store.Session{ + ID: childID, + RunID: "run-c", + ClaudeSessionID: "claude-c", + ParentSessionID: parentID, + Status: store.SessionStatusCompleted, + Query: "child query", + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + CompletedAt: &time.Time{}, + } + if err := d.store.CreateSession(ctx, child); err != nil { + t.Fatalf("Failed to create child: %v", err) + } + + // Add child events + cEvents := []*store.ConversationEvent{ + { + SessionID: childID, + ClaudeSessionID: "claude-c", + EventType: store.EventTypeMessage, + Role: "user", + Content: "child query", + }, + { + SessionID: childID, + ClaudeSessionID: "claude-c", + EventType: store.EventTypeMessage, + Role: "assistant", + Content: "child response", + }, + } + for _, event := range cEvents { + if err := d.store.AddConversationEvent(ctx, event); err != nil { + t.Fatalf("Failed to add child event: %v", err) + } + } + + // Get conversation for child session - should include full history + req := rpc.GetConversationRequest{ + SessionID: childID, + } + result, err := sendRPC(t, "getConversation", req) + if err != nil { + t.Fatalf("Failed to get conversation: %v", err) + } + + // Parse response + var conversation rpc.GetConversationResponse + if err := json.Unmarshal(result, &conversation); err != nil { + t.Fatalf("Failed to unmarshal conversation: %v", err) + } + + // Verify we got all events in correct order + if len(conversation.Events) != 6 { + t.Errorf("Expected 6 events (2 from each session), got %d", len(conversation.Events)) + } + + // Verify chronological order + expectedContents := []string{ + "grandparent query", + "grandparent response", + "parent query", + "parent response", + "child query", + "child response", + } + + for i, event := range conversation.Events { + if i < len(expectedContents) && event.Content != expectedContents[i] { + t.Errorf("Event %d: expected content '%s', got '%s'", + i, expectedContents[i], event.Content) + } + } + }) +} + +func containsError(err error, substr string) bool { + if err == nil { + return false + } + return len(err.Error()) >= len(substr) && contains(err.Error(), substr) +} + +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/hld/rpc/handlers.go b/hld/rpc/handlers.go index f138d00..2d36877 100644 --- a/hld/rpc/handlers.go +++ b/hld/rpc/handlers.go @@ -146,7 +146,7 @@ func (h *SessionHandlers) HandleGetConversation(ctx context.Context, params json // Get conversation by Claude session ID events, err = h.store.GetConversation(ctx, req.ClaudeSessionID) } else { - // Get conversation by session ID + // Get conversation by session ID - always returns full history including parents events, err = h.store.GetSessionConversation(ctx, req.SessionID) } @@ -205,6 +205,7 @@ func (h *SessionHandlers) HandleGetSessionState(ctx context.Context, params json ID: session.ID, RunID: session.RunID, ClaudeSessionID: session.ClaudeSessionID, + ParentSessionID: session.ParentSessionID, Status: session.Status, Query: session.Query, Model: session.Model, @@ -233,10 +234,62 @@ func (h *SessionHandlers) HandleGetSessionState(ctx context.Context, params json }, nil } +// HandleContinueSession handles the ContinueSession RPC method +func (h *SessionHandlers) HandleContinueSession(ctx context.Context, params json.RawMessage) (interface{}, error) { + var req ContinueSessionRequest + if err := json.Unmarshal(params, &req); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Validate required fields + if req.SessionID == "" { + return nil, fmt.Errorf("session_id is required") + } + if req.Query == "" { + return nil, fmt.Errorf("query is required") + } + + // Build session config for manager + config := session.ContinueSessionConfig{ + ParentSessionID: req.SessionID, + Query: req.Query, + SystemPrompt: req.SystemPrompt, + AppendSystemPrompt: req.AppendSystemPrompt, + PermissionPromptTool: req.PermissionPromptTool, + AllowedTools: req.AllowedTools, + DisallowedTools: req.DisallowedTools, + CustomInstructions: req.CustomInstructions, + MaxTurns: req.MaxTurns, + } + + // Parse MCP config if provided as JSON string + if req.MCPConfig != "" { + var mcpConfig claudecode.MCPConfig + if err := json.Unmarshal([]byte(req.MCPConfig), &mcpConfig); err != nil { + return nil, fmt.Errorf("invalid mcp_config JSON: %w", err) + } + config.MCPConfig = &mcpConfig + } + + // Continue session + session, err := h.manager.ContinueSession(ctx, config) + if err != nil { + return nil, err + } + + return &ContinueSessionResponse{ + SessionID: session.ID, + RunID: session.RunID, + ClaudeSessionID: "", // Will be populated when events stream in + ParentSessionID: req.SessionID, + }, nil +} + // Register registers all session handlers with the RPC server func (h *SessionHandlers) Register(server *Server) { server.Register("launchSession", h.HandleLaunchSession) server.Register("listSessions", h.HandleListSessions) server.Register("getConversation", h.HandleGetConversation) server.Register("getSessionState", h.HandleGetSessionState) + server.Register("continueSession", h.HandleContinueSession) } diff --git a/hld/rpc/handlers_continue_session_test.go b/hld/rpc/handlers_continue_session_test.go new file mode 100644 index 0000000..f3f45de --- /dev/null +++ b/hld/rpc/handlers_continue_session_test.go @@ -0,0 +1,248 @@ +package rpc + +import ( + "context" + "testing" + "time" + + "github.com/humanlayer/humanlayer/hld/session" + "github.com/humanlayer/humanlayer/hld/store" + "go.uber.org/mock/gomock" +) + +func TestHandleContinueSession(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManager := session.NewMockSessionManager(ctrl) + mockStore := store.NewMockConversationStore(ctrl) + handlers := NewSessionHandlers(mockManager, mockStore) + + testCases := []struct { + name string + request string + setupMocks func() + expectedError string + validateResp func(t *testing.T, resp *ContinueSessionResponse) + }{ + { + name: "successful continue session", + request: `{ + "session_id": "parent-123", + "query": "follow up question", + "system_prompt": "You are helpful", + "max_turns": 5 + }`, + setupMocks: func() { + mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) { + // Validate request + if req.ParentSessionID != "parent-123" { + t.Errorf("Expected parent session ID 'parent-123', got %s", req.ParentSessionID) + } + if req.Query != "follow up question" { + t.Errorf("Expected query 'follow up question', got %s", req.Query) + } + if req.SystemPrompt != "You are helpful" { + t.Errorf("Expected system prompt 'You are helpful', got %s", req.SystemPrompt) + } + if req.MaxTurns != 5 { + t.Errorf("Expected max turns 5, got %d", req.MaxTurns) + } + + // Return mock session + return &session.Session{ + ID: "child-456", + RunID: "run-child", + Status: session.StatusRunning, + StartTime: time.Now(), + }, nil + }) + + // Note: We don't mock GetSession here because the handler correctly + // returns empty claude_session_id (it's not available until events stream) + }, + validateResp: func(t *testing.T, resp *ContinueSessionResponse) { + if resp.SessionID != "child-456" { + t.Errorf("Expected session ID 'child-456', got %s", resp.SessionID) + } + if resp.RunID != "run-child" { + t.Errorf("Expected run ID 'run-child', got %s", resp.RunID) + } + // claude_session_id should be empty initially (populated when events stream) + if resp.ClaudeSessionID != "" { + t.Errorf("Expected empty claude session ID initially, got %s", resp.ClaudeSessionID) + } + if resp.ParentSessionID != "parent-123" { + t.Errorf("Expected parent session ID 'parent-123', got %s", resp.ParentSessionID) + } + }, + }, + { + name: "continue session with MCP config", + request: `{ + "session_id": "parent-mcp", + "query": "with mcp", + "mcp_config": "{\"mcpServers\": {\"test\": {\"command\": \"node\"}}}" + }`, + setupMocks: func() { + mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) { + // Validate MCP config was parsed + if req.MCPConfig == nil { + t.Error("Expected MCP config to be parsed") + } + if req.MCPConfig.MCPServers == nil { + t.Error("Expected MCP servers to be set") + } + if _, ok := req.MCPConfig.MCPServers["test"]; !ok { + t.Error("Expected 'test' server in MCP config") + } + + return &session.Session{ + ID: "child-mcp", + RunID: "run-mcp", + Status: session.StatusRunning, + StartTime: time.Now(), + }, nil + }) + + // No GetSession mock needed - claude_session_id won't be available yet + }, + validateResp: func(t *testing.T, resp *ContinueSessionResponse) { + if resp.SessionID != "child-mcp" { + t.Errorf("Expected session ID 'child-mcp', got %s", resp.SessionID) + } + }, + }, + { + name: "missing session ID", + request: `{"query": "no session"}`, + setupMocks: func() { + // No mocks needed - validation fails early + }, + expectedError: "session_id is required", + }, + { + name: "missing query", + request: `{"session_id": "parent-123"}`, + setupMocks: func() { + // No mocks needed - validation fails early + }, + expectedError: "query is required", + }, + { + name: "invalid MCP config JSON", + request: `{"session_id": "parent-123", "query": "test", "mcp_config": "invalid json"}`, + setupMocks: func() { + // No mocks needed - validation fails early + }, + expectedError: "invalid mcp_config JSON", + }, + { + name: "invalid JSON request", + request: `{invalid json}`, + setupMocks: func() { + // No mocks needed - parsing fails + }, + expectedError: "invalid request:", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.setupMocks() + + resp, err := handlers.HandleContinueSession(context.Background(), []byte(tc.request)) + + if tc.expectedError != "" { + if err == nil { + t.Errorf("Expected error containing '%s', got nil", tc.expectedError) + } else if !containsStr(err.Error(), tc.expectedError) { + t.Errorf("Expected error containing '%s', got '%s'", tc.expectedError, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + continueResp, ok := resp.(*ContinueSessionResponse) + if !ok { + t.Fatalf("Expected *ContinueSessionResponse, got %T", resp) + } + + if tc.validateResp != nil { + tc.validateResp(t, continueResp) + } + }) + } +} + +func TestHandleContinueSession_ToolsConfiguration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockManager := session.NewMockSessionManager(ctrl) + mockStore := store.NewMockConversationStore(ctrl) + handlers := NewSessionHandlers(mockManager, mockStore) + + request := `{ + "session_id": "parent-tools", + "query": "with tools config", + "permission_prompt_tool": "mcp__humanlayer__tool", + "allowed_tools": ["tool1", "tool2"], + "disallowed_tools": ["dangerous_tool"] + }` + + mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) { + // Validate tools configuration + if req.PermissionPromptTool != "mcp__humanlayer__tool" { + t.Errorf("Expected permission prompt tool 'mcp__humanlayer__tool', got %s", req.PermissionPromptTool) + } + if len(req.AllowedTools) != 2 || req.AllowedTools[0] != "tool1" || req.AllowedTools[1] != "tool2" { + t.Errorf("Expected allowed tools [tool1, tool2], got %v", req.AllowedTools) + } + if len(req.DisallowedTools) != 1 || req.DisallowedTools[0] != "dangerous_tool" { + t.Errorf("Expected disallowed tools [dangerous_tool], got %v", req.DisallowedTools) + } + + return &session.Session{ + ID: "child-tools", + RunID: "run-tools", + Status: session.StatusRunning, + StartTime: time.Now(), + }, nil + }) + + // No GetSession mock needed - claude_session_id won't be available yet + + resp, err := handlers.HandleContinueSession(context.Background(), []byte(request)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + continueResp, ok := resp.(*ContinueSessionResponse) + if !ok { + t.Fatalf("Expected *ContinueSessionResponse, got %T", resp) + } + + if continueResp.SessionID != "child-tools" { + t.Errorf("Expected session ID 'child-tools', got %s", continueResp.SessionID) + } +} + +func containsStr(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && contains(s, substr)) +} + +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/hld/rpc/types.go b/hld/rpc/types.go index d76d91b..11612f5 100644 --- a/hld/rpc/types.go +++ b/hld/rpc/types.go @@ -58,6 +58,7 @@ type SessionState struct { ID string `json:"id"` RunID string `json:"run_id"` ClaudeSessionID string `json:"claude_session_id,omitempty"` + ParentSessionID string `json:"parent_session_id,omitempty"` Status string `json:"status"` // starting, running, completed, failed, waiting_input Query string `json:"query"` Model string `json:"model,omitempty"` @@ -75,3 +76,25 @@ type SessionState struct { type GetSessionStateResponse struct { Session SessionState `json:"session"` } + +// ContinueSessionRequest is the request for continuing an existing session +type ContinueSessionRequest struct { + SessionID string `json:"session_id"` // The session to continue (required) + Query string `json:"query"` // The new query/message to send (required) + SystemPrompt string `json:"system_prompt,omitempty"` // Override system prompt + AppendSystemPrompt string `json:"append_system_prompt,omitempty"` // Append to system prompt + MCPConfig string `json:"mcp_config,omitempty"` // JSON string of MCP config (to avoid import cycle) + PermissionPromptTool string `json:"permission_prompt_tool,omitempty"` // MCP tool for permission prompts + AllowedTools []string `json:"allowed_tools,omitempty"` // Allowed tools list + DisallowedTools []string `json:"disallowed_tools,omitempty"` // Disallowed tools list + CustomInstructions string `json:"custom_instructions,omitempty"` // Custom instructions + MaxTurns int `json:"max_turns,omitempty"` // Max conversation turns +} + +// ContinueSessionResponse is the response for continuing a session +type ContinueSessionResponse struct { + SessionID string `json:"session_id"` // The new session ID + RunID string `json:"run_id"` // The new run ID + ClaudeSessionID string `json:"claude_session_id"` // The new Claude session ID (unique for each resume) + ParentSessionID string `json:"parent_session_id"` // The parent session ID +} diff --git a/hld/session/manager.go b/hld/session/manager.go index 5d9d052..09a30c2 100644 --- a/hld/session/manager.go +++ b/hld/session/manager.go @@ -437,3 +437,143 @@ func (m *Manager) processStreamEvent(ctx context.Context, sessionID string, clau return nil } + +// ContinueSession resumes an existing completed session with a new query and optional config overrides +func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig) (*Session, error) { + // Get parent session from database + parentSession, err := m.store.GetSession(ctx, req.ParentSessionID) + if err != nil { + return nil, fmt.Errorf("failed to get parent session: %w", err) + } + + // Validate parent session status + if parentSession.Status != store.SessionStatusCompleted { + return nil, fmt.Errorf("cannot continue session with status %s (must be completed)", parentSession.Status) + } + + // Validate parent session has claude_session_id + if parentSession.ClaudeSessionID == "" { + return nil, fmt.Errorf("parent session missing claude_session_id (cannot resume)") + } + + // Build config for resumed session + // Start with minimal required fields + config := claudecode.SessionConfig{ + Query: req.Query, + SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag + OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON + // Model and WorkingDir are inherited from Claude's internal state + } + + // Apply optional overrides + if req.SystemPrompt != "" { + config.SystemPrompt = req.SystemPrompt + } + if req.AppendSystemPrompt != "" { + config.AppendSystemPrompt = req.AppendSystemPrompt + } + if req.MCPConfig != nil { + config.MCPConfig = req.MCPConfig + } + if req.PermissionPromptTool != "" { + config.PermissionPromptTool = req.PermissionPromptTool + } + if len(req.AllowedTools) > 0 { + config.AllowedTools = req.AllowedTools + } + if len(req.DisallowedTools) > 0 { + config.DisallowedTools = req.DisallowedTools + } + if req.CustomInstructions != "" { + config.CustomInstructions = req.CustomInstructions + } + if req.MaxTurns > 0 { + config.MaxTurns = req.MaxTurns + } + + // Create new session with parent reference + sessionID := uuid.New().String() + runID := uuid.New().String() + + // Store session in database with parent reference + dbSession := store.NewSessionFromConfig(sessionID, runID, config) + dbSession.ParentSessionID = req.ParentSessionID + // Note: ClaudeSessionID will be captured from streaming events (will be different from parent) + if err := m.store.CreateSession(ctx, dbSession); err != nil { + return nil, fmt.Errorf("failed to store session in database: %w", err) + } + + // Add run_id to MCP server environments + if config.MCPConfig != nil { + for name, server := range config.MCPConfig.MCPServers { + if server.Env == nil { + server.Env = make(map[string]string) + } + server.Env["HUMANLAYER_RUN_ID"] = runID + config.MCPConfig.MCPServers[name] = server + } + + // Store MCP servers configuration + servers, err := store.MCPServersFromConfig(sessionID, config.MCPConfig.MCPServers) + if err != nil { + slog.Error("failed to convert MCP servers", "error", err) + } else if err := m.store.StoreMCPServers(ctx, sessionID, servers); err != nil { + slog.Error("failed to store MCP servers", "error", err) + } + } + + // Launch resumed Claude session + claudeSession, err := m.client.Launch(config) + if err != nil { + m.updateSessionStatus(ctx, sessionID, StatusFailed, err.Error()) + return nil, fmt.Errorf("failed to launch resumed Claude session: %w", err) + } + + // Store active Claude process + m.mu.Lock() + m.activeProcesses[sessionID] = claudeSession + m.mu.Unlock() + + // Update database with running status + statusRunning := string(StatusRunning) + now := time.Now() + update := store.SessionUpdate{ + Status: &statusRunning, + LastActivityAt: &now, + } + if err := m.store.UpdateSession(ctx, sessionID, update); err != nil { + slog.Error("failed to update session status to running", "error", err) + } + + // Publish status change event + if m.eventBus != nil { + m.eventBus.Publish(bus.Event{ + Type: bus.EventSessionStatusChanged, + Data: map[string]interface{}{ + "session_id": sessionID, + "run_id": runID, + "parent_session_id": req.ParentSessionID, + "old_status": string(StatusStarting), + "new_status": string(StatusRunning), + }, + }) + } + + // Monitor session lifecycle in background + go m.monitorSession(ctx, sessionID, runID, claudeSession, time.Now(), config) + + slog.Info("continued Claude session", + "session_id", sessionID, + "parent_session_id", req.ParentSessionID, + "run_id", runID, + "query", req.Query) + + // Return minimal session info + return &Session{ + ID: sessionID, + RunID: runID, + Status: StatusRunning, + StartTime: time.Now(), + Config: config, + }, nil +} diff --git a/hld/session/manager_test.go b/hld/session/manager_test.go index 0fc361d..be2084d 100644 --- a/hld/session/manager_test.go +++ b/hld/session/manager_test.go @@ -1,6 +1,7 @@ package session import ( + "context" "fmt" "testing" "time" @@ -114,3 +115,280 @@ func TestGetSessionInfo(t *testing.T) { // Note: Most of the old tests were removed because they tested internal implementation // details (in-memory maps) that no longer exist. The real functionality is now // tested by the integration tests which use actual SQLite database. + +func TestContinueSession_ValidatesParentExists(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockConversationStore(ctrl) + manager, _ := NewManager(nil, mockStore) + + // Test parent not found + mockStore.EXPECT().GetSession(gomock.Any(), "not-found").Return(nil, fmt.Errorf("session not found")) + + req := ContinueSessionConfig{ + ParentSessionID: "not-found", + Query: "continue this", + } + _, err := manager.ContinueSession(context.Background(), req) + if err == nil { + t.Error("Expected error for non-existent parent session") + } + if err.Error() != "failed to get parent session: session not found" { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestContinueSession_ValidatesParentStatus(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockConversationStore(ctrl) + manager, _ := NewManager(nil, mockStore) + + testCases := []struct { + name string + parentStatus string + expectedError string + }{ + { + name: "running session", + parentStatus: store.SessionStatusRunning, + expectedError: "cannot continue session with status running (must be completed)", + }, + { + name: "failed session", + parentStatus: store.SessionStatusFailed, + expectedError: "cannot continue session with status failed (must be completed)", + }, + { + name: "starting session", + parentStatus: store.SessionStatusStarting, + expectedError: "cannot continue session with status starting (must be completed)", + }, + { + name: "waiting input session", + parentStatus: store.SessionStatusWaitingInput, + expectedError: "cannot continue session with status waiting_input (must be completed)", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + parentSession := &store.Session{ + ID: "parent-1", + RunID: "run-1", + ClaudeSessionID: "claude-1", + Status: tc.parentStatus, + Query: "original query", + CreatedAt: time.Now(), + } + mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil) + + req := ContinueSessionConfig{ + ParentSessionID: "parent-1", + Query: "continue this", + } + _, err := manager.ContinueSession(context.Background(), req) + if err == nil { + t.Error("Expected error for non-completed parent session") + } + if err.Error() != tc.expectedError { + t.Errorf("Expected error '%s', got: %v", tc.expectedError, err) + } + }) + } +} + +func TestContinueSession_ValidatesClaudeSessionID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockConversationStore(ctrl) + manager, _ := NewManager(nil, mockStore) + + // Parent without claude_session_id + parentSession := &store.Session{ + ID: "parent-1", + RunID: "run-1", + ClaudeSessionID: "", // Empty + Status: store.SessionStatusCompleted, + Query: "original query", + CreatedAt: time.Now(), + } + mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil) + + req := ContinueSessionConfig{ + ParentSessionID: "parent-1", + Query: "continue this", + } + _, err := manager.ContinueSession(context.Background(), req) + if err == nil { + t.Error("Expected error for parent without claude_session_id") + } + if err.Error() != "parent session missing claude_session_id (cannot resume)" { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockConversationStore(ctrl) + manager, _ := NewManager(nil, mockStore) + + // Mock parent session + parentSession := &store.Session{ + ID: "parent-1", + RunID: "run-1", + ClaudeSessionID: "claude-1", + Status: store.SessionStatusCompleted, + Query: "original query", + Model: "claude-3-opus", + WorkingDir: "/test/dir", + CreatedAt: time.Now(), + } + mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil) + + // Expect session creation with parent reference + mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx interface{}, session *store.Session) error { + // Validate the created session + if session.ParentSessionID != "parent-1" { + t.Errorf("Expected parent_session_id to be 'parent-1', got '%s'", session.ParentSessionID) + } + if session.Query != "continue this" { + t.Errorf("Expected query 'continue this', got '%s'", session.Query) + } + if session.Status != store.SessionStatusStarting { + t.Errorf("Expected status 'starting', got '%s'", session.Status) + } + // Should not inherit claude_session_id (will be set from streaming events) + if session.ClaudeSessionID != "" { + t.Errorf("Expected empty claude_session_id, got '%s'", session.ClaudeSessionID) + } + return nil + }) + + // Expect status update to running (we can't test the full flow without mocking Claude client) + // May be called twice if Claude fails to launch in background + mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + req := ContinueSessionConfig{ + ParentSessionID: "parent-1", + Query: "continue this", + } + + // Try to continue session - this tests our logic, not Claude launch + session, err := manager.ContinueSession(context.Background(), req) + + // If Claude binary exists, it might succeed; if not, it will fail + // Either way, our mock expectations should have been met (session created with parent) + if err != nil { + // Expected - Claude binary might not exist in test environment + if !containsError(err, "failed to launch resumed Claude session") { + t.Errorf("Unexpected error: %v", err) + } + } else { + // Claude launched successfully - verify session has expected properties + if session.ID == "" { + t.Error("Expected session ID to be set") + } + if session.RunID == "" { + t.Error("Expected run ID to be set") + } + } +} + +func TestContinueSession_HandlesOptionalOverrides(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockConversationStore(ctrl) + manager, _ := NewManager(nil, mockStore) + + // Mock parent session + parentSession := &store.Session{ + ID: "parent-1", + RunID: "run-1", + ClaudeSessionID: "claude-1", + Status: store.SessionStatusCompleted, + Query: "original query", + CreatedAt: time.Now(), + } + mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil) + + // Test with various overrides + req := ContinueSessionConfig{ + ParentSessionID: "parent-1", + Query: "continue with overrides", + SystemPrompt: "You are a pirate", + AppendSystemPrompt: "Always say arr", + PermissionPromptTool: "mcp__custom__tool", + AllowedTools: []string{"tool1", "tool2"}, + DisallowedTools: []string{"dangerous_tool"}, + CustomInstructions: "Be helpful", + MaxTurns: 5, + } + + // Expect session creation + mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx interface{}, session *store.Session) error { + // Validate overrides are stored + if session.SystemPrompt != "You are a pirate" { + t.Errorf("Expected system prompt override, got '%s'", session.SystemPrompt) + } + if session.CustomInstructions != "Be helpful" { + t.Errorf("Expected custom instructions override, got '%s'", session.CustomInstructions) + } + if session.MaxTurns != 5 { + t.Errorf("Expected max turns 5, got %d", session.MaxTurns) + } + return nil + }) + + // Expect status update + // May be called twice if Claude fails to launch in background + mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + session, err := manager.ContinueSession(context.Background(), req) + + // Test passes if our mock expectations were met (session created with overrides) + // Whether Claude actually launches depends on the environment + if err != nil { + // Expected - Claude binary might not exist + if !containsError(err, "failed to launch resumed Claude session") { + t.Errorf("Unexpected error: %v", err) + } + } else { + // Claude launched - verify session properties + if session.ID == "" { + t.Error("Expected session ID to be set") + } + if session.RunID == "" { + t.Error("Expected run ID to be set") + } + } +} + +// Helper function to check if error contains a substring +func containsError(err error, substr string) bool { + if err == nil { + return false + } + return contains(err.Error(), substr) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsStr(s, substr)) +} + +func containsStr(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/hld/session/types.go b/hld/session/types.go index f8c6feb..f08ff9a 100644 --- a/hld/session/types.go +++ b/hld/session/types.go @@ -42,11 +42,28 @@ type Info struct { Result *claudecode.Result `json:"result,omitempty"` } +// ContinueSessionConfig contains the configuration for continuing a session +type ContinueSessionConfig struct { + ParentSessionID string // The parent session to resume from + Query string // The new query + SystemPrompt string // Optional system prompt override + AppendSystemPrompt string // Optional append to system prompt + MCPConfig *claudecode.MCPConfig // Optional MCP config override + PermissionPromptTool string // Optional permission prompt tool + AllowedTools []string // Optional allowed tools override + DisallowedTools []string // Optional disallowed tools override + CustomInstructions string // Optional custom instructions + MaxTurns int // Optional max turns override +} + // SessionManager defines the interface for managing Claude Code sessions type SessionManager interface { // LaunchSession starts a new Claude Code session LaunchSession(ctx context.Context, config claudecode.SessionConfig) (*Session, error) + // ContinueSession resumes an existing completed session with a new query and optional config overrides + ContinueSession(ctx context.Context, req ContinueSessionConfig) (*Session, error) + // GetSessionInfo returns session info from the database by ID GetSessionInfo(sessionID string) (*Info, error) diff --git a/hld/store/sqlite.go b/hld/store/sqlite.go index 72d6725..e24a8ee 100644 --- a/hld/store/sqlite.go +++ b/hld/store/sqlite.go @@ -8,6 +8,7 @@ import ( "log/slog" "os" "path/filepath" + "strings" claudecode "github.com/humanlayer/humanlayer/claudecode-go" _ "github.com/mattn/go-sqlite3" @@ -523,24 +524,98 @@ func (s *SQLiteStore) GetConversation(ctx context.Context, claudeSessionID strin return events, nil } -// GetSessionConversation retrieves all events for a session +// GetSessionConversation retrieves all events for a session including parent history func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID string) ([]*ConversationEvent, error) { - // First get the claude_session_id for this session - var claudeSessionID sql.NullString - err := s.db.QueryRowContext(ctx, - "SELECT claude_session_id FROM sessions WHERE id = ?", - sessionID, - ).Scan(&claudeSessionID) - if err != nil { - return nil, fmt.Errorf("failed to get session: %w", err) + // Walk up the parent chain to get all related claude session IDs + claudeSessionIDs := []string{} + currentID := sessionID + + for currentID != "" { + var claudeSessionID sql.NullString + var parentID sql.NullString + + err := s.db.QueryRowContext(ctx, + "SELECT claude_session_id, parent_session_id FROM sessions WHERE id = ?", + currentID, + ).Scan(&claudeSessionID, &parentID) + if err != nil { + if err == sql.ErrNoRows { + break // Session not found, stop walking + } + return nil, fmt.Errorf("failed to get session: %w", err) + } + + // Add claude session ID if present (in reverse order for chronological events) + if claudeSessionID.Valid && claudeSessionID.String != "" { + claudeSessionIDs = append([]string{claudeSessionID.String}, claudeSessionIDs...) + } + + // Move to parent + if parentID.Valid { + currentID = parentID.String + } else { + currentID = "" + } } - if !claudeSessionID.Valid { - // No claude session yet, return empty + if len(claudeSessionIDs) == 0 { + // No claude sessions yet, return empty return []*ConversationEvent{}, nil } - return s.GetConversation(ctx, claudeSessionID.String) + // Get all events for all claude session IDs in chronological order + placeholders := make([]string, len(claudeSessionIDs)) + args := make([]interface{}, len(claudeSessionIDs)) + for i, id := range claudeSessionIDs { + placeholders[i] = "?" + args[i] = id + } + + // Build query that orders by the position in the claude session ID list first + // This ensures parent events come before child events + orderCases := make([]string, len(claudeSessionIDs)) + for i := range claudeSessionIDs { + orderCases[i] = fmt.Sprintf("WHEN claude_session_id = ? THEN %d", i) + args = append(args, claudeSessionIDs[i]) + } + + query := fmt.Sprintf(` + SELECT id, session_id, claude_session_id, sequence, event_type, created_at, + role, content, + tool_id, tool_name, tool_input_json, + tool_result_for_id, tool_result_content, + is_completed, approval_status, approval_id + FROM conversation_events + WHERE claude_session_id IN (%s) + ORDER BY + CASE %s END, + sequence + `, strings.Join(placeholders, ","), strings.Join(orderCases, " ")) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to get conversation: %w", err) + } + defer func() { _ = rows.Close() }() + + var events []*ConversationEvent + for rows.Next() { + event := &ConversationEvent{} + err := rows.Scan( + &event.ID, &event.SessionID, &event.ClaudeSessionID, + &event.Sequence, &event.EventType, &event.CreatedAt, + &event.Role, &event.Content, + &event.ToolID, &event.ToolName, &event.ToolInputJSON, + &event.ToolResultForID, &event.ToolResultContent, + &event.IsCompleted, &event.ApprovalStatus, &event.ApprovalID, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan event: %w", err) + } + events = append(events, event) + } + + return events, nil } // GetPendingToolCall finds the most recent uncompleted tool call for a given session and tool name diff --git a/hld/store/sqlite_test.go b/hld/store/sqlite_test.go index cf0784b..b054419 100644 --- a/hld/store/sqlite_test.go +++ b/hld/store/sqlite_test.go @@ -262,3 +262,203 @@ func TestSQLiteStore(t *testing.T) { require.Len(t, sessions, 2) }) } + +func TestGetSessionConversationWithParentChain(t *testing.T) { + // Create temp database + tmpDir, err := os.MkdirTemp("", "hld-test-parent-*") + require.NoError(t, err) + defer func() { _ = os.RemoveAll(tmpDir) }() + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewSQLiteStore(dbPath) + require.NoError(t, err) + defer func() { _ = store.Close() }() + + ctx := context.Background() + + // Create parent session + parentSession := &Session{ + ID: "parent-1", + RunID: "run-1", + ClaudeSessionID: "claude-parent", + Query: "Tell me about Go", + Status: SessionStatusCompleted, + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + } + err = store.CreateSession(ctx, parentSession) + require.NoError(t, err) + + // Add events to parent session + parentEvents := []*ConversationEvent{ + { + SessionID: "parent-1", + ClaudeSessionID: "claude-parent", + Sequence: 1, + EventType: EventTypeMessage, + Role: "user", + Content: "Tell me about Go", + CreatedAt: time.Now(), + }, + { + SessionID: "parent-1", + ClaudeSessionID: "claude-parent", + Sequence: 2, + EventType: EventTypeMessage, + Role: "assistant", + Content: "Go is a statically typed programming language...", + CreatedAt: time.Now().Add(1 * time.Second), + }, + } + for _, event := range parentEvents { + err = store.AddConversationEvent(ctx, event) + require.NoError(t, err) + } + + // Create child session + childSession := &Session{ + ID: "child-1", + RunID: "run-2", + ClaudeSessionID: "claude-child", + ParentSessionID: "parent-1", + Query: "Tell me more about goroutines", + Status: SessionStatusCompleted, + CreatedAt: time.Now().Add(5 * time.Second), + LastActivityAt: time.Now().Add(5 * time.Second), + } + err = store.CreateSession(ctx, childSession) + require.NoError(t, err) + + // Add events to child session + childEvents := []*ConversationEvent{ + { + SessionID: "child-1", + ClaudeSessionID: "claude-child", + Sequence: 1, + EventType: EventTypeMessage, + Role: "user", + Content: "Tell me more about goroutines", + CreatedAt: time.Now().Add(10 * time.Second), + }, + { + SessionID: "child-1", + ClaudeSessionID: "claude-child", + Sequence: 2, + EventType: EventTypeMessage, + Role: "assistant", + Content: "Goroutines are lightweight threads...", + CreatedAt: time.Now().Add(11 * time.Second), + }, + } + for _, event := range childEvents { + err = store.AddConversationEvent(ctx, event) + require.NoError(t, err) + } + + // Create grandchild session + grandchildSession := &Session{ + ID: "grandchild-1", + RunID: "run-3", + ClaudeSessionID: "claude-grandchild", + ParentSessionID: "child-1", + Query: "How do channels work?", + Status: SessionStatusRunning, + CreatedAt: time.Now().Add(20 * time.Second), + LastActivityAt: time.Now().Add(20 * time.Second), + } + err = store.CreateSession(ctx, grandchildSession) + require.NoError(t, err) + + // Add events to grandchild session + grandchildEvents := []*ConversationEvent{ + { + SessionID: "grandchild-1", + ClaudeSessionID: "claude-grandchild", + Sequence: 1, + EventType: EventTypeMessage, + Role: "user", + Content: "How do channels work?", + CreatedAt: time.Now().Add(25 * time.Second), + }, + { + SessionID: "grandchild-1", + ClaudeSessionID: "claude-grandchild", + Sequence: 2, + EventType: EventTypeMessage, + Role: "assistant", + Content: "Channels are Go's way of communication...", + CreatedAt: time.Now().Add(26 * time.Second), + }, + } + for _, event := range grandchildEvents { + err = store.AddConversationEvent(ctx, event) + require.NoError(t, err) + } + + t.Run("GetSessionConversation_IncludesFullHistory", func(t *testing.T) { + // Get conversation for grandchild - should include all parent events + events, err := store.GetSessionConversation(ctx, "grandchild-1") + require.NoError(t, err) + require.Len(t, events, 6) // 2 from parent + 2 from child + 2 from grandchild + + // Verify chronological order + expectedContents := []string{ + "Tell me about Go", + "Go is a statically typed programming language...", + "Tell me more about goroutines", + "Goroutines are lightweight threads...", + "How do channels work?", + "Channels are Go's way of communication...", + } + + for i, event := range events { + require.Equal(t, expectedContents[i], event.Content) + } + }) + + t.Run("GetSessionConversation_SessionWithoutClaudeID", func(t *testing.T) { + // Create session without claude_session_id yet + newSession := &Session{ + ID: "new-session", + RunID: "run-4", + Query: "New query", + Status: SessionStatusStarting, + CreatedAt: time.Now(), + LastActivityAt: time.Now(), + } + err = store.CreateSession(ctx, newSession) + require.NoError(t, err) + + // Should return empty events + events, err := store.GetSessionConversation(ctx, "new-session") + require.NoError(t, err) + require.Len(t, events, 0) + }) + + t.Run("GetSessionConversation_NonExistentSession", func(t *testing.T) { + // Should return empty events for non-existent session + events, err := store.GetSessionConversation(ctx, "does-not-exist") + require.NoError(t, err) + require.Len(t, events, 0) + }) + + t.Run("GetSessionConversation_NoParent", func(t *testing.T) { + // Get conversation for parent session (no parents) + events, err := store.GetSessionConversation(ctx, "parent-1") + require.NoError(t, err) + require.Len(t, events, 2) // Only parent's events + }) + + t.Run("GetSessionConversation_MiddleOfChain", func(t *testing.T) { + // Get conversation for child (middle of chain) + events, err := store.GetSessionConversation(ctx, "child-1") + require.NoError(t, err) + require.Len(t, events, 4) // Parent's 2 + child's 2 + + // Verify we get parent events first, then child + require.Equal(t, "Tell me about Go", events[0].Content) + require.Equal(t, "Go is a statically typed programming language...", events[1].Content) + require.Equal(t, "Tell me more about goroutines", events[2].Content) + require.Equal(t, "Goroutines are lightweight threads...", events[3].Content) + }) +}