mirror of
https://github.com/humanlayer/humanlayer.git
synced 2025-08-20 19:01:22 +03:00
Add multi-turn conversation support with session continuation (#207)
* resuming session work * pull request template formatting * make claudecode-go happy with race conditions * tests for session continuation
This commit is contained in:
43
.github/PULL_REQUEST_TEMPLATE.md
vendored
43
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,44 +1,13 @@
|
||||
<!--
|
||||
Please make sure you've read and understood our contributing guidelines;
|
||||
https://github.com/humanlayer/humanlayer/blob/master/CONTRIBUTING.md
|
||||
## What problem(s) was I solving?
|
||||
|
||||
If this is a bug fix, make sure your description includes "fixes #xxxx", or
|
||||
"closes #xxxx"
|
||||
## What user-facing changes did I ship?
|
||||
|
||||
Please provide the following information:
|
||||
## How I implemented it
|
||||
|
||||
-->
|
||||
|
||||
**What I did**
|
||||
|
||||
<!--
|
||||
A succint description of the user-facing changes, at most 2-3 bullet points.
|
||||
Less about the code changes that happened, more about the outcomes this work drives
|
||||
-->
|
||||
|
||||
**How I did it**
|
||||
|
||||
<!--
|
||||
Describe the work you did, tests you ran, changes you made, etc
|
||||
-->
|
||||
## How to verify it
|
||||
|
||||
- [ ] I have ensured `make check test` passes
|
||||
|
||||
**How to verify it**
|
||||
## Description for the changelog
|
||||
|
||||
<!--
|
||||
Describe how to test this, which examples to run to verify it, or anything
|
||||
else that would be helpful to folks exploring this work
|
||||
-->
|
||||
|
||||
**Description for the changelog**
|
||||
|
||||
<!--
|
||||
Write a short (one line) summary that describes the changes in this
|
||||
pull request for inclusion in the changelog:
|
||||
-->
|
||||
|
||||
<!--
|
||||
**- A picture of a cute animal (not mandatory but encouraged)**
|
||||
|
||||
-->
|
||||
## A picture of a cute animal (not mandatory but encouraged)
|
||||
|
||||
@@ -185,7 +185,7 @@ func (c *Client) Launch(config SessionConfig) (*Session, error) {
|
||||
// Wait for process to complete in background
|
||||
go func() {
|
||||
// Wait for the command to exit
|
||||
session.err = cmd.Wait()
|
||||
session.SetError(cmd.Wait())
|
||||
|
||||
// IMPORTANT: Wait for parsing to complete before signaling done.
|
||||
// This ensures that all output has been read and processed before
|
||||
@@ -213,8 +213,8 @@ func (c *Client) LaunchAndWait(config SessionConfig) (*Result, error) {
|
||||
func (s *Session) Wait() (*Result, error) {
|
||||
<-s.done
|
||||
|
||||
if s.err != nil && s.result == nil {
|
||||
return nil, fmt.Errorf("claude process failed: %w", s.err)
|
||||
if err := s.Error(); err != nil && s.result == nil {
|
||||
return nil, fmt.Errorf("claude process failed: %w", err)
|
||||
}
|
||||
|
||||
return s.result, nil
|
||||
@@ -232,9 +232,11 @@ func (s *Session) Kill() error {
|
||||
func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
var stderrBuf strings.Builder
|
||||
stderrDone := make(chan struct{})
|
||||
|
||||
// Capture stderr in background
|
||||
go func() {
|
||||
defer close(stderrDone)
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := stderr.Read(buf)
|
||||
@@ -283,9 +285,12 @@ func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) {
|
||||
s.Events <- event
|
||||
}
|
||||
|
||||
// Wait for stderr reading to complete before accessing the buffer
|
||||
<-stderrDone
|
||||
|
||||
// If we got stderr output, that's an error
|
||||
if stderrOutput := stderrBuf.String(); stderrOutput != "" {
|
||||
s.err = fmt.Errorf("claude error: %s", stderrOutput)
|
||||
s.SetError(fmt.Errorf("claude error: %s", stderrOutput))
|
||||
}
|
||||
|
||||
// Close events channel when done parsing
|
||||
@@ -296,7 +301,7 @@ func (s *Session) parseStreamingJSON(stdout, stderr io.Reader) {
|
||||
func (s *Session) parseSingleJSON(stdout, stderr io.Reader) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.err = fmt.Errorf("panic in parseSingleJSON: %v", r)
|
||||
s.SetError(fmt.Errorf("panic in parseSingleJSON: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -304,26 +309,26 @@ func (s *Session) parseSingleJSON(stdout, stderr io.Reader) {
|
||||
|
||||
// Read all stdout
|
||||
if _, err := io.Copy(&stdoutBuf, stdout); err != nil {
|
||||
s.err = fmt.Errorf("failed to read stdout: %w", err)
|
||||
s.SetError(fmt.Errorf("failed to read stdout: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Read all stderr
|
||||
if _, err := io.Copy(&stderrBuf, stderr); err != nil {
|
||||
s.err = fmt.Errorf("failed to read stderr: %w", err)
|
||||
s.SetError(fmt.Errorf("failed to read stderr: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON result
|
||||
output := stdoutBuf.String()
|
||||
if output == "" {
|
||||
s.err = fmt.Errorf("no output from claude")
|
||||
s.SetError(fmt.Errorf("no output from claude"))
|
||||
return
|
||||
}
|
||||
|
||||
var result Result
|
||||
if err := json.Unmarshal([]byte(output), &result); err != nil {
|
||||
s.err = fmt.Errorf("failed to parse JSON output: %w\nOutput was: %s", err, output)
|
||||
s.SetError(fmt.Errorf("failed to parse JSON output: %w\nOutput was: %s", err, output))
|
||||
return
|
||||
}
|
||||
s.result = &result
|
||||
@@ -333,7 +338,7 @@ func (s *Session) parseSingleJSON(stdout, stderr io.Reader) {
|
||||
if stderrOutput := stderrBuf.String(); stderrOutput != "" {
|
||||
// Don't override result if we got valid JSON
|
||||
if s.result == nil {
|
||||
s.err = fmt.Errorf("claude error: %s", stderrOutput)
|
||||
s.SetError(fmt.Errorf("claude error: %s", stderrOutput))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,13 +349,13 @@ func (s *Session) parseTextOutput(stdout, stderr io.Reader) {
|
||||
|
||||
// Read all stdout
|
||||
if _, err := io.Copy(&stdoutBuf, stdout); err != nil {
|
||||
s.err = fmt.Errorf("failed to read stdout: %w", err)
|
||||
s.SetError(fmt.Errorf("failed to read stdout: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Read all stderr
|
||||
if _, err := io.Copy(&stderrBuf, stderr); err != nil {
|
||||
s.err = fmt.Errorf("failed to read stderr: %w", err)
|
||||
s.SetError(fmt.Errorf("failed to read stderr: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -365,6 +370,6 @@ func (s *Session) parseTextOutput(stdout, stderr io.Reader) {
|
||||
|
||||
// If we got stderr output, that's an error
|
||||
if stderrOutput := stderrBuf.String(); stderrOutput != "" {
|
||||
s.err = fmt.Errorf("claude error: %s", stderrOutput)
|
||||
s.SetError(fmt.Errorf("claude error: %s", stderrOutput))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package claudecode
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -140,5 +141,24 @@ type Session struct {
|
||||
cmd *exec.Cmd
|
||||
done chan struct{}
|
||||
result *Result
|
||||
|
||||
// Thread-safe error handling
|
||||
mu sync.RWMutex
|
||||
err error
|
||||
}
|
||||
|
||||
// SetError safely sets the error
|
||||
func (s *Session) SetError(err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
}
|
||||
|
||||
// Error safely gets the error
|
||||
func (s *Session) Error() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
@@ -256,6 +256,15 @@ func (c *client) ListSessions() (*rpc.ListSessionsResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ContinueSession continues an existing completed session with a new query
|
||||
func (c *client) ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error) {
|
||||
var resp rpc.ContinueSessionResponse
|
||||
if err := c.call("continueSession", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// FetchApprovals fetches pending approvals from the daemon
|
||||
func (c *client) FetchApprovals(sessionID string) ([]approval.PendingApproval, error) {
|
||||
req := rpc.FetchApprovalsRequest{
|
||||
|
||||
@@ -18,6 +18,9 @@ type Client interface {
|
||||
// ListSessions lists all active sessions
|
||||
ListSessions() (*rpc.ListSessionsResponse, error)
|
||||
|
||||
// ContinueSession continues an existing completed session with a new query
|
||||
ContinueSession(req rpc.ContinueSessionRequest) (*rpc.ContinueSessionResponse, error)
|
||||
|
||||
// FetchApprovals fetches pending approvals from the daemon
|
||||
FetchApprovals(sessionID string) ([]approval.PendingApproval, error)
|
||||
|
||||
|
||||
515
hld/daemon/daemon_continue_session_integration_test.go
Normal file
515
hld/daemon/daemon_continue_session_integration_test.go
Normal file
@@ -0,0 +1,515 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/bus"
|
||||
"github.com/humanlayer/humanlayer/hld/config"
|
||||
"github.com/humanlayer/humanlayer/hld/internal/testutil"
|
||||
"github.com/humanlayer/humanlayer/hld/rpc"
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
)
|
||||
|
||||
func TestIntegrationContinueSession(t *testing.T) {
|
||||
// Use test-specific socket path
|
||||
socketPath := testutil.SocketPath(t, "continue-session")
|
||||
|
||||
// Create daemon components
|
||||
eventBus := bus.NewEventBus()
|
||||
sqliteStore, err := store.NewSQLiteStore(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create store: %v", err)
|
||||
}
|
||||
defer func() { _ = sqliteStore.Close() }()
|
||||
|
||||
sessionManager, err := session.NewManager(eventBus, sqliteStore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create daemon
|
||||
d := &Daemon{
|
||||
socketPath: socketPath,
|
||||
config: &config.Config{SocketPath: socketPath, DatabasePath: ":memory:"},
|
||||
eventBus: eventBus,
|
||||
store: sqliteStore,
|
||||
sessions: sessionManager,
|
||||
rpcServer: rpc.NewServer(),
|
||||
}
|
||||
|
||||
// Register RPC handlers
|
||||
sessionHandlers := rpc.NewSessionHandlers(sessionManager, sqliteStore)
|
||||
sessionHandlers.Register(d.rpcServer)
|
||||
|
||||
// Start daemon
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := d.Run(ctx); err != nil {
|
||||
t.Logf("daemon run error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for daemon to be ready
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Create helper function to send RPC requests
|
||||
sendRPC := func(t *testing.T, method string, params interface{}) (json.RawMessage, error) {
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to daemon: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
request := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
"id": 1,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(append(data, '\n')); err != nil {
|
||||
t.Fatalf("failed to write request: %v", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(conn)
|
||||
if !scanner.Scan() {
|
||||
t.Fatal("no response received")
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if errObj, ok := response["error"]; ok {
|
||||
if errMap, ok := errObj.(map[string]interface{}); ok {
|
||||
if msg, ok := errMap["message"].(string); ok {
|
||||
return nil, fmt.Errorf("%s", msg)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("RPC error: %v", errObj)
|
||||
}
|
||||
|
||||
if result, ok := response["result"]; ok {
|
||||
resultBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal result: %v", err)
|
||||
}
|
||||
return resultBytes, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no result in response")
|
||||
}
|
||||
|
||||
t.Run("ContinueSession_RequiresCompletedParent", func(t *testing.T) {
|
||||
// Create a parent session that's still running
|
||||
parentSessionID := "parent-running"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-parent",
|
||||
ClaudeSessionID: "claude-parent",
|
||||
Status: store.SessionStatusRunning, // Not completed
|
||||
Query: "original query",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
|
||||
// Insert parent session directly into database
|
||||
if err := d.store.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Try to continue the running session
|
||||
req := rpc.ContinueSessionRequest{
|
||||
SessionID: parentSessionID,
|
||||
Query: "continue this",
|
||||
}
|
||||
|
||||
_, err := sendRPC(t, "continueSession", req)
|
||||
if err == nil {
|
||||
t.Error("Expected error when continuing running session")
|
||||
}
|
||||
if err.Error() != "cannot continue session with status running (must be completed)" {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContinueSession_RequiresClaudeSessionID", func(t *testing.T) {
|
||||
// Create a parent session without claude_session_id
|
||||
parentSessionID := "parent-no-claude"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-no-claude",
|
||||
ClaudeSessionID: "", // Missing
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
// Insert parent session
|
||||
if err := d.store.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Try to continue without claude_session_id
|
||||
req := rpc.ContinueSessionRequest{
|
||||
SessionID: parentSessionID,
|
||||
Query: "continue this",
|
||||
}
|
||||
|
||||
_, err := sendRPC(t, "continueSession", req)
|
||||
if err == nil {
|
||||
t.Error("Expected error when continuing session without claude_session_id")
|
||||
}
|
||||
if err.Error() != "parent session missing claude_session_id (cannot resume)" {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContinueSession_CreatesChildSession", func(t *testing.T) {
|
||||
// Create a valid completed parent session
|
||||
parentSessionID := "parent-valid"
|
||||
claudeSessionID := "claude-valid"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-valid",
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
Model: "claude-3-opus",
|
||||
WorkingDir: "/test/dir",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
// Insert parent session
|
||||
if err := d.store.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Add some conversation history to parent
|
||||
events := []*store.ConversationEvent{
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "original query",
|
||||
},
|
||||
{
|
||||
SessionID: parentSessionID,
|
||||
ClaudeSessionID: claudeSessionID,
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "Original response",
|
||||
},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
if err := d.store.AddConversationEvent(ctx, event); err != nil {
|
||||
t.Fatalf("Failed to add conversation event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Continue the session
|
||||
req := rpc.ContinueSessionRequest{
|
||||
SessionID: parentSessionID,
|
||||
Query: "follow up question",
|
||||
SystemPrompt: "You are helpful",
|
||||
CustomInstructions: "Be concise",
|
||||
MaxTurns: 3,
|
||||
}
|
||||
|
||||
result, err := sendRPC(t, "continueSession", req)
|
||||
if err != nil {
|
||||
// Expected - Claude binary might not exist in test environment
|
||||
if err.Error() != "failed to continue session: failed to launch resumed Claude session: failed to start claude: exec: \"claude\": executable file not found in $PATH" {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
// Even if Claude fails to launch, we should have created the session
|
||||
return
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var resp rpc.ContinueSessionResponse
|
||||
if err := json.Unmarshal(result, &resp); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if resp.SessionID == "" {
|
||||
t.Error("Expected session ID in response")
|
||||
}
|
||||
if resp.RunID == "" {
|
||||
t.Error("Expected run ID in response")
|
||||
}
|
||||
if resp.ParentSessionID != parentSessionID {
|
||||
t.Errorf("Expected parent_session_id %s, got %s", parentSessionID, resp.ParentSessionID)
|
||||
}
|
||||
|
||||
// Verify the new session was created with parent reference
|
||||
newSession, err := d.store.GetSession(ctx, resp.SessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
if newSession.ParentSessionID != parentSessionID {
|
||||
t.Errorf("Expected parent_session_id %s, got %s", parentSessionID, newSession.ParentSessionID)
|
||||
}
|
||||
if newSession.Query != "follow up question" {
|
||||
t.Errorf("Expected query 'follow up question', got %s", newSession.Query)
|
||||
}
|
||||
if newSession.SystemPrompt != "You are helpful" {
|
||||
t.Errorf("Expected system prompt override, got %s", newSession.SystemPrompt)
|
||||
}
|
||||
if newSession.CustomInstructions != "Be concise" {
|
||||
t.Errorf("Expected custom instructions override, got %s", newSession.CustomInstructions)
|
||||
}
|
||||
if newSession.MaxTurns != 3 {
|
||||
t.Errorf("Expected max turns 3, got %d", newSession.MaxTurns)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ContinueSession_HandlesOptionalMCPConfig", func(t *testing.T) {
|
||||
// Create parent session
|
||||
parentSessionID := "parent-mcp"
|
||||
parentSession := &store.Session{
|
||||
ID: parentSessionID,
|
||||
RunID: "run-mcp",
|
||||
ClaudeSessionID: "claude-mcp",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
|
||||
if err := d.store.CreateSession(ctx, parentSession); err != nil {
|
||||
t.Fatalf("Failed to create parent session: %v", err)
|
||||
}
|
||||
|
||||
// Create MCP config
|
||||
mcpConfig := map[string]interface{}{
|
||||
"mcpServers": map[string]interface{}{
|
||||
"test-server": map[string]interface{}{
|
||||
"command": "node",
|
||||
"args": []string{"server.js"},
|
||||
"env": map[string]string{
|
||||
"TEST": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpConfigJSON, err := json.Marshal(mcpConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal MCP config: %v", err)
|
||||
}
|
||||
|
||||
// Continue with MCP config
|
||||
req := rpc.ContinueSessionRequest{
|
||||
SessionID: parentSessionID,
|
||||
Query: "with mcp",
|
||||
MCPConfig: string(mcpConfigJSON),
|
||||
}
|
||||
|
||||
_, err = sendRPC(t, "continueSession", req)
|
||||
// Expected to fail (no Claude binary), but session should be created
|
||||
if err != nil && !containsError(err, "failed to launch resumed Claude session") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetConversation_IncludesParentHistory", func(t *testing.T) {
|
||||
// Create a chain of sessions: grandparent -> parent -> child
|
||||
grandparentID := "grandparent"
|
||||
parentID := "parent-chain"
|
||||
childID := "child"
|
||||
|
||||
// Create grandparent session
|
||||
grandparent := &store.Session{
|
||||
ID: grandparentID,
|
||||
RunID: "run-gp",
|
||||
ClaudeSessionID: "claude-gp",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "grandparent query",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
if err := d.store.CreateSession(ctx, grandparent); err != nil {
|
||||
t.Fatalf("Failed to create grandparent: %v", err)
|
||||
}
|
||||
|
||||
// Add grandparent events
|
||||
gpEvents := []*store.ConversationEvent{
|
||||
{
|
||||
SessionID: grandparentID,
|
||||
ClaudeSessionID: "claude-gp",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "grandparent query",
|
||||
},
|
||||
{
|
||||
SessionID: grandparentID,
|
||||
ClaudeSessionID: "claude-gp",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "grandparent response",
|
||||
},
|
||||
}
|
||||
for _, event := range gpEvents {
|
||||
if err := d.store.AddConversationEvent(ctx, event); err != nil {
|
||||
t.Fatalf("Failed to add grandparent event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create parent session
|
||||
parent := &store.Session{
|
||||
ID: parentID,
|
||||
RunID: "run-p",
|
||||
ClaudeSessionID: "claude-p",
|
||||
ParentSessionID: grandparentID,
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "parent query",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
if err := d.store.CreateSession(ctx, parent); err != nil {
|
||||
t.Fatalf("Failed to create parent: %v", err)
|
||||
}
|
||||
|
||||
// Add parent events
|
||||
pEvents := []*store.ConversationEvent{
|
||||
{
|
||||
SessionID: parentID,
|
||||
ClaudeSessionID: "claude-p",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "parent query",
|
||||
},
|
||||
{
|
||||
SessionID: parentID,
|
||||
ClaudeSessionID: "claude-p",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "parent response",
|
||||
},
|
||||
}
|
||||
for _, event := range pEvents {
|
||||
if err := d.store.AddConversationEvent(ctx, event); err != nil {
|
||||
t.Fatalf("Failed to add parent event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create child session
|
||||
child := &store.Session{
|
||||
ID: childID,
|
||||
RunID: "run-c",
|
||||
ClaudeSessionID: "claude-c",
|
||||
ParentSessionID: parentID,
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "child query",
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
CompletedAt: &time.Time{},
|
||||
}
|
||||
if err := d.store.CreateSession(ctx, child); err != nil {
|
||||
t.Fatalf("Failed to create child: %v", err)
|
||||
}
|
||||
|
||||
// Add child events
|
||||
cEvents := []*store.ConversationEvent{
|
||||
{
|
||||
SessionID: childID,
|
||||
ClaudeSessionID: "claude-c",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "child query",
|
||||
},
|
||||
{
|
||||
SessionID: childID,
|
||||
ClaudeSessionID: "claude-c",
|
||||
EventType: store.EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "child response",
|
||||
},
|
||||
}
|
||||
for _, event := range cEvents {
|
||||
if err := d.store.AddConversationEvent(ctx, event); err != nil {
|
||||
t.Fatalf("Failed to add child event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get conversation for child session - should include full history
|
||||
req := rpc.GetConversationRequest{
|
||||
SessionID: childID,
|
||||
}
|
||||
result, err := sendRPC(t, "getConversation", req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get conversation: %v", err)
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var conversation rpc.GetConversationResponse
|
||||
if err := json.Unmarshal(result, &conversation); err != nil {
|
||||
t.Fatalf("Failed to unmarshal conversation: %v", err)
|
||||
}
|
||||
|
||||
// Verify we got all events in correct order
|
||||
if len(conversation.Events) != 6 {
|
||||
t.Errorf("Expected 6 events (2 from each session), got %d", len(conversation.Events))
|
||||
}
|
||||
|
||||
// Verify chronological order
|
||||
expectedContents := []string{
|
||||
"grandparent query",
|
||||
"grandparent response",
|
||||
"parent query",
|
||||
"parent response",
|
||||
"child query",
|
||||
"child response",
|
||||
}
|
||||
|
||||
for i, event := range conversation.Events {
|
||||
if i < len(expectedContents) && event.Content != expectedContents[i] {
|
||||
t.Errorf("Event %d: expected content '%s', got '%s'",
|
||||
i, expectedContents[i], event.Content)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func containsError(err error, substr string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return len(err.Error()) >= len(substr) && contains(err.Error(), substr)
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func (h *SessionHandlers) HandleGetConversation(ctx context.Context, params json
|
||||
// Get conversation by Claude session ID
|
||||
events, err = h.store.GetConversation(ctx, req.ClaudeSessionID)
|
||||
} else {
|
||||
// Get conversation by session ID
|
||||
// Get conversation by session ID - always returns full history including parents
|
||||
events, err = h.store.GetSessionConversation(ctx, req.SessionID)
|
||||
}
|
||||
|
||||
@@ -205,6 +205,7 @@ func (h *SessionHandlers) HandleGetSessionState(ctx context.Context, params json
|
||||
ID: session.ID,
|
||||
RunID: session.RunID,
|
||||
ClaudeSessionID: session.ClaudeSessionID,
|
||||
ParentSessionID: session.ParentSessionID,
|
||||
Status: session.Status,
|
||||
Query: session.Query,
|
||||
Model: session.Model,
|
||||
@@ -233,10 +234,62 @@ func (h *SessionHandlers) HandleGetSessionState(ctx context.Context, params json
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleContinueSession handles the ContinueSession RPC method
|
||||
func (h *SessionHandlers) HandleContinueSession(ctx context.Context, params json.RawMessage) (interface{}, error) {
|
||||
var req ContinueSessionRequest
|
||||
if err := json.Unmarshal(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("invalid request: %w", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.SessionID == "" {
|
||||
return nil, fmt.Errorf("session_id is required")
|
||||
}
|
||||
if req.Query == "" {
|
||||
return nil, fmt.Errorf("query is required")
|
||||
}
|
||||
|
||||
// Build session config for manager
|
||||
config := session.ContinueSessionConfig{
|
||||
ParentSessionID: req.SessionID,
|
||||
Query: req.Query,
|
||||
SystemPrompt: req.SystemPrompt,
|
||||
AppendSystemPrompt: req.AppendSystemPrompt,
|
||||
PermissionPromptTool: req.PermissionPromptTool,
|
||||
AllowedTools: req.AllowedTools,
|
||||
DisallowedTools: req.DisallowedTools,
|
||||
CustomInstructions: req.CustomInstructions,
|
||||
MaxTurns: req.MaxTurns,
|
||||
}
|
||||
|
||||
// Parse MCP config if provided as JSON string
|
||||
if req.MCPConfig != "" {
|
||||
var mcpConfig claudecode.MCPConfig
|
||||
if err := json.Unmarshal([]byte(req.MCPConfig), &mcpConfig); err != nil {
|
||||
return nil, fmt.Errorf("invalid mcp_config JSON: %w", err)
|
||||
}
|
||||
config.MCPConfig = &mcpConfig
|
||||
}
|
||||
|
||||
// Continue session
|
||||
session, err := h.manager.ContinueSession(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ContinueSessionResponse{
|
||||
SessionID: session.ID,
|
||||
RunID: session.RunID,
|
||||
ClaudeSessionID: "", // Will be populated when events stream in
|
||||
ParentSessionID: req.SessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register registers all session handlers with the RPC server
|
||||
func (h *SessionHandlers) Register(server *Server) {
|
||||
server.Register("launchSession", h.HandleLaunchSession)
|
||||
server.Register("listSessions", h.HandleListSessions)
|
||||
server.Register("getConversation", h.HandleGetConversation)
|
||||
server.Register("getSessionState", h.HandleGetSessionState)
|
||||
server.Register("continueSession", h.HandleContinueSession)
|
||||
}
|
||||
|
||||
248
hld/rpc/handlers_continue_session_test.go
Normal file
248
hld/rpc/handlers_continue_session_test.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/humanlayer/humanlayer/hld/session"
|
||||
"github.com/humanlayer/humanlayer/hld/store"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestHandleContinueSession(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockManager := session.NewMockSessionManager(ctrl)
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
handlers := NewSessionHandlers(mockManager, mockStore)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request string
|
||||
setupMocks func()
|
||||
expectedError string
|
||||
validateResp func(t *testing.T, resp *ContinueSessionResponse)
|
||||
}{
|
||||
{
|
||||
name: "successful continue session",
|
||||
request: `{
|
||||
"session_id": "parent-123",
|
||||
"query": "follow up question",
|
||||
"system_prompt": "You are helpful",
|
||||
"max_turns": 5
|
||||
}`,
|
||||
setupMocks: func() {
|
||||
mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) {
|
||||
// Validate request
|
||||
if req.ParentSessionID != "parent-123" {
|
||||
t.Errorf("Expected parent session ID 'parent-123', got %s", req.ParentSessionID)
|
||||
}
|
||||
if req.Query != "follow up question" {
|
||||
t.Errorf("Expected query 'follow up question', got %s", req.Query)
|
||||
}
|
||||
if req.SystemPrompt != "You are helpful" {
|
||||
t.Errorf("Expected system prompt 'You are helpful', got %s", req.SystemPrompt)
|
||||
}
|
||||
if req.MaxTurns != 5 {
|
||||
t.Errorf("Expected max turns 5, got %d", req.MaxTurns)
|
||||
}
|
||||
|
||||
// Return mock session
|
||||
return &session.Session{
|
||||
ID: "child-456",
|
||||
RunID: "run-child",
|
||||
Status: session.StatusRunning,
|
||||
StartTime: time.Now(),
|
||||
}, nil
|
||||
})
|
||||
|
||||
// Note: We don't mock GetSession here because the handler correctly
|
||||
// returns empty claude_session_id (it's not available until events stream)
|
||||
},
|
||||
validateResp: func(t *testing.T, resp *ContinueSessionResponse) {
|
||||
if resp.SessionID != "child-456" {
|
||||
t.Errorf("Expected session ID 'child-456', got %s", resp.SessionID)
|
||||
}
|
||||
if resp.RunID != "run-child" {
|
||||
t.Errorf("Expected run ID 'run-child', got %s", resp.RunID)
|
||||
}
|
||||
// claude_session_id should be empty initially (populated when events stream)
|
||||
if resp.ClaudeSessionID != "" {
|
||||
t.Errorf("Expected empty claude session ID initially, got %s", resp.ClaudeSessionID)
|
||||
}
|
||||
if resp.ParentSessionID != "parent-123" {
|
||||
t.Errorf("Expected parent session ID 'parent-123', got %s", resp.ParentSessionID)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "continue session with MCP config",
|
||||
request: `{
|
||||
"session_id": "parent-mcp",
|
||||
"query": "with mcp",
|
||||
"mcp_config": "{\"mcpServers\": {\"test\": {\"command\": \"node\"}}}"
|
||||
}`,
|
||||
setupMocks: func() {
|
||||
mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) {
|
||||
// Validate MCP config was parsed
|
||||
if req.MCPConfig == nil {
|
||||
t.Error("Expected MCP config to be parsed")
|
||||
}
|
||||
if req.MCPConfig.MCPServers == nil {
|
||||
t.Error("Expected MCP servers to be set")
|
||||
}
|
||||
if _, ok := req.MCPConfig.MCPServers["test"]; !ok {
|
||||
t.Error("Expected 'test' server in MCP config")
|
||||
}
|
||||
|
||||
return &session.Session{
|
||||
ID: "child-mcp",
|
||||
RunID: "run-mcp",
|
||||
Status: session.StatusRunning,
|
||||
StartTime: time.Now(),
|
||||
}, nil
|
||||
})
|
||||
|
||||
// No GetSession mock needed - claude_session_id won't be available yet
|
||||
},
|
||||
validateResp: func(t *testing.T, resp *ContinueSessionResponse) {
|
||||
if resp.SessionID != "child-mcp" {
|
||||
t.Errorf("Expected session ID 'child-mcp', got %s", resp.SessionID)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing session ID",
|
||||
request: `{"query": "no session"}`,
|
||||
setupMocks: func() {
|
||||
// No mocks needed - validation fails early
|
||||
},
|
||||
expectedError: "session_id is required",
|
||||
},
|
||||
{
|
||||
name: "missing query",
|
||||
request: `{"session_id": "parent-123"}`,
|
||||
setupMocks: func() {
|
||||
// No mocks needed - validation fails early
|
||||
},
|
||||
expectedError: "query is required",
|
||||
},
|
||||
{
|
||||
name: "invalid MCP config JSON",
|
||||
request: `{"session_id": "parent-123", "query": "test", "mcp_config": "invalid json"}`,
|
||||
setupMocks: func() {
|
||||
// No mocks needed - validation fails early
|
||||
},
|
||||
expectedError: "invalid mcp_config JSON",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON request",
|
||||
request: `{invalid json}`,
|
||||
setupMocks: func() {
|
||||
// No mocks needed - parsing fails
|
||||
},
|
||||
expectedError: "invalid request:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.setupMocks()
|
||||
|
||||
resp, err := handlers.HandleContinueSession(context.Background(), []byte(tc.request))
|
||||
|
||||
if tc.expectedError != "" {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
|
||||
} else if !containsStr(err.Error(), tc.expectedError) {
|
||||
t.Errorf("Expected error containing '%s', got '%s'", tc.expectedError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
continueResp, ok := resp.(*ContinueSessionResponse)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *ContinueSessionResponse, got %T", resp)
|
||||
}
|
||||
|
||||
if tc.validateResp != nil {
|
||||
tc.validateResp(t, continueResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleContinueSession_ToolsConfiguration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockManager := session.NewMockSessionManager(ctrl)
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
handlers := NewSessionHandlers(mockManager, mockStore)
|
||||
|
||||
request := `{
|
||||
"session_id": "parent-tools",
|
||||
"query": "with tools config",
|
||||
"permission_prompt_tool": "mcp__humanlayer__tool",
|
||||
"allowed_tools": ["tool1", "tool2"],
|
||||
"disallowed_tools": ["dangerous_tool"]
|
||||
}`
|
||||
|
||||
mockManager.EXPECT().ContinueSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, req session.ContinueSessionConfig) (*session.Session, error) {
|
||||
// Validate tools configuration
|
||||
if req.PermissionPromptTool != "mcp__humanlayer__tool" {
|
||||
t.Errorf("Expected permission prompt tool 'mcp__humanlayer__tool', got %s", req.PermissionPromptTool)
|
||||
}
|
||||
if len(req.AllowedTools) != 2 || req.AllowedTools[0] != "tool1" || req.AllowedTools[1] != "tool2" {
|
||||
t.Errorf("Expected allowed tools [tool1, tool2], got %v", req.AllowedTools)
|
||||
}
|
||||
if len(req.DisallowedTools) != 1 || req.DisallowedTools[0] != "dangerous_tool" {
|
||||
t.Errorf("Expected disallowed tools [dangerous_tool], got %v", req.DisallowedTools)
|
||||
}
|
||||
|
||||
return &session.Session{
|
||||
ID: "child-tools",
|
||||
RunID: "run-tools",
|
||||
Status: session.StatusRunning,
|
||||
StartTime: time.Now(),
|
||||
}, nil
|
||||
})
|
||||
|
||||
// No GetSession mock needed - claude_session_id won't be available yet
|
||||
|
||||
resp, err := handlers.HandleContinueSession(context.Background(), []byte(request))
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
continueResp, ok := resp.(*ContinueSessionResponse)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *ContinueSessionResponse, got %T", resp)
|
||||
}
|
||||
|
||||
if continueResp.SessionID != "child-tools" {
|
||||
t.Errorf("Expected session ID 'child-tools', got %s", continueResp.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func containsStr(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && contains(s, substr))
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -58,6 +58,7 @@ type SessionState struct {
|
||||
ID string `json:"id"`
|
||||
RunID string `json:"run_id"`
|
||||
ClaudeSessionID string `json:"claude_session_id,omitempty"`
|
||||
ParentSessionID string `json:"parent_session_id,omitempty"`
|
||||
Status string `json:"status"` // starting, running, completed, failed, waiting_input
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model,omitempty"`
|
||||
@@ -75,3 +76,25 @@ type SessionState struct {
|
||||
type GetSessionStateResponse struct {
|
||||
Session SessionState `json:"session"`
|
||||
}
|
||||
|
||||
// ContinueSessionRequest is the request for continuing an existing session
|
||||
type ContinueSessionRequest struct {
|
||||
SessionID string `json:"session_id"` // The session to continue (required)
|
||||
Query string `json:"query"` // The new query/message to send (required)
|
||||
SystemPrompt string `json:"system_prompt,omitempty"` // Override system prompt
|
||||
AppendSystemPrompt string `json:"append_system_prompt,omitempty"` // Append to system prompt
|
||||
MCPConfig string `json:"mcp_config,omitempty"` // JSON string of MCP config (to avoid import cycle)
|
||||
PermissionPromptTool string `json:"permission_prompt_tool,omitempty"` // MCP tool for permission prompts
|
||||
AllowedTools []string `json:"allowed_tools,omitempty"` // Allowed tools list
|
||||
DisallowedTools []string `json:"disallowed_tools,omitempty"` // Disallowed tools list
|
||||
CustomInstructions string `json:"custom_instructions,omitempty"` // Custom instructions
|
||||
MaxTurns int `json:"max_turns,omitempty"` // Max conversation turns
|
||||
}
|
||||
|
||||
// ContinueSessionResponse is the response for continuing a session
|
||||
type ContinueSessionResponse struct {
|
||||
SessionID string `json:"session_id"` // The new session ID
|
||||
RunID string `json:"run_id"` // The new run ID
|
||||
ClaudeSessionID string `json:"claude_session_id"` // The new Claude session ID (unique for each resume)
|
||||
ParentSessionID string `json:"parent_session_id"` // The parent session ID
|
||||
}
|
||||
|
||||
@@ -437,3 +437,143 @@ func (m *Manager) processStreamEvent(ctx context.Context, sessionID string, clau
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContinueSession resumes an existing completed session with a new query and optional config overrides
|
||||
func (m *Manager) ContinueSession(ctx context.Context, req ContinueSessionConfig) (*Session, error) {
|
||||
// Get parent session from database
|
||||
parentSession, err := m.store.GetSession(ctx, req.ParentSessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get parent session: %w", err)
|
||||
}
|
||||
|
||||
// Validate parent session status
|
||||
if parentSession.Status != store.SessionStatusCompleted {
|
||||
return nil, fmt.Errorf("cannot continue session with status %s (must be completed)", parentSession.Status)
|
||||
}
|
||||
|
||||
// Validate parent session has claude_session_id
|
||||
if parentSession.ClaudeSessionID == "" {
|
||||
return nil, fmt.Errorf("parent session missing claude_session_id (cannot resume)")
|
||||
}
|
||||
|
||||
// Build config for resumed session
|
||||
// Start with minimal required fields
|
||||
config := claudecode.SessionConfig{
|
||||
Query: req.Query,
|
||||
SessionID: parentSession.ClaudeSessionID, // This triggers --resume flag
|
||||
OutputFormat: claudecode.OutputStreamJSON, // Always use streaming JSON
|
||||
// Model and WorkingDir are inherited from Claude's internal state
|
||||
}
|
||||
|
||||
// Apply optional overrides
|
||||
if req.SystemPrompt != "" {
|
||||
config.SystemPrompt = req.SystemPrompt
|
||||
}
|
||||
if req.AppendSystemPrompt != "" {
|
||||
config.AppendSystemPrompt = req.AppendSystemPrompt
|
||||
}
|
||||
if req.MCPConfig != nil {
|
||||
config.MCPConfig = req.MCPConfig
|
||||
}
|
||||
if req.PermissionPromptTool != "" {
|
||||
config.PermissionPromptTool = req.PermissionPromptTool
|
||||
}
|
||||
if len(req.AllowedTools) > 0 {
|
||||
config.AllowedTools = req.AllowedTools
|
||||
}
|
||||
if len(req.DisallowedTools) > 0 {
|
||||
config.DisallowedTools = req.DisallowedTools
|
||||
}
|
||||
if req.CustomInstructions != "" {
|
||||
config.CustomInstructions = req.CustomInstructions
|
||||
}
|
||||
if req.MaxTurns > 0 {
|
||||
config.MaxTurns = req.MaxTurns
|
||||
}
|
||||
|
||||
// Create new session with parent reference
|
||||
sessionID := uuid.New().String()
|
||||
runID := uuid.New().String()
|
||||
|
||||
// Store session in database with parent reference
|
||||
dbSession := store.NewSessionFromConfig(sessionID, runID, config)
|
||||
dbSession.ParentSessionID = req.ParentSessionID
|
||||
// Note: ClaudeSessionID will be captured from streaming events (will be different from parent)
|
||||
if err := m.store.CreateSession(ctx, dbSession); err != nil {
|
||||
return nil, fmt.Errorf("failed to store session in database: %w", err)
|
||||
}
|
||||
|
||||
// Add run_id to MCP server environments
|
||||
if config.MCPConfig != nil {
|
||||
for name, server := range config.MCPConfig.MCPServers {
|
||||
if server.Env == nil {
|
||||
server.Env = make(map[string]string)
|
||||
}
|
||||
server.Env["HUMANLAYER_RUN_ID"] = runID
|
||||
config.MCPConfig.MCPServers[name] = server
|
||||
}
|
||||
|
||||
// Store MCP servers configuration
|
||||
servers, err := store.MCPServersFromConfig(sessionID, config.MCPConfig.MCPServers)
|
||||
if err != nil {
|
||||
slog.Error("failed to convert MCP servers", "error", err)
|
||||
} else if err := m.store.StoreMCPServers(ctx, sessionID, servers); err != nil {
|
||||
slog.Error("failed to store MCP servers", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Launch resumed Claude session
|
||||
claudeSession, err := m.client.Launch(config)
|
||||
if err != nil {
|
||||
m.updateSessionStatus(ctx, sessionID, StatusFailed, err.Error())
|
||||
return nil, fmt.Errorf("failed to launch resumed Claude session: %w", err)
|
||||
}
|
||||
|
||||
// Store active Claude process
|
||||
m.mu.Lock()
|
||||
m.activeProcesses[sessionID] = claudeSession
|
||||
m.mu.Unlock()
|
||||
|
||||
// Update database with running status
|
||||
statusRunning := string(StatusRunning)
|
||||
now := time.Now()
|
||||
update := store.SessionUpdate{
|
||||
Status: &statusRunning,
|
||||
LastActivityAt: &now,
|
||||
}
|
||||
if err := m.store.UpdateSession(ctx, sessionID, update); err != nil {
|
||||
slog.Error("failed to update session status to running", "error", err)
|
||||
}
|
||||
|
||||
// Publish status change event
|
||||
if m.eventBus != nil {
|
||||
m.eventBus.Publish(bus.Event{
|
||||
Type: bus.EventSessionStatusChanged,
|
||||
Data: map[string]interface{}{
|
||||
"session_id": sessionID,
|
||||
"run_id": runID,
|
||||
"parent_session_id": req.ParentSessionID,
|
||||
"old_status": string(StatusStarting),
|
||||
"new_status": string(StatusRunning),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Monitor session lifecycle in background
|
||||
go m.monitorSession(ctx, sessionID, runID, claudeSession, time.Now(), config)
|
||||
|
||||
slog.Info("continued Claude session",
|
||||
"session_id", sessionID,
|
||||
"parent_session_id", req.ParentSessionID,
|
||||
"run_id", runID,
|
||||
"query", req.Query)
|
||||
|
||||
// Return minimal session info
|
||||
return &Session{
|
||||
ID: sessionID,
|
||||
RunID: runID,
|
||||
Status: StatusRunning,
|
||||
StartTime: time.Now(),
|
||||
Config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -114,3 +115,280 @@ func TestGetSessionInfo(t *testing.T) {
|
||||
// Note: Most of the old tests were removed because they tested internal implementation
|
||||
// details (in-memory maps) that no longer exist. The real functionality is now
|
||||
// tested by the integration tests which use actual SQLite database.
|
||||
|
||||
func TestContinueSession_ValidatesParentExists(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
manager, _ := NewManager(nil, mockStore)
|
||||
|
||||
// Test parent not found
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "not-found").Return(nil, fmt.Errorf("session not found"))
|
||||
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "not-found",
|
||||
Query: "continue this",
|
||||
}
|
||||
_, err := manager.ContinueSession(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent parent session")
|
||||
}
|
||||
if err.Error() != "failed to get parent session: session not found" {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinueSession_ValidatesParentStatus(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
manager, _ := NewManager(nil, mockStore)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
parentStatus string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "running session",
|
||||
parentStatus: store.SessionStatusRunning,
|
||||
expectedError: "cannot continue session with status running (must be completed)",
|
||||
},
|
||||
{
|
||||
name: "failed session",
|
||||
parentStatus: store.SessionStatusFailed,
|
||||
expectedError: "cannot continue session with status failed (must be completed)",
|
||||
},
|
||||
{
|
||||
name: "starting session",
|
||||
parentStatus: store.SessionStatusStarting,
|
||||
expectedError: "cannot continue session with status starting (must be completed)",
|
||||
},
|
||||
{
|
||||
name: "waiting input session",
|
||||
parentStatus: store.SessionStatusWaitingInput,
|
||||
expectedError: "cannot continue session with status waiting_input (must be completed)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
parentSession := &store.Session{
|
||||
ID: "parent-1",
|
||||
RunID: "run-1",
|
||||
ClaudeSessionID: "claude-1",
|
||||
Status: tc.parentStatus,
|
||||
Query: "original query",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "parent-1",
|
||||
Query: "continue this",
|
||||
}
|
||||
_, err := manager.ContinueSession(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-completed parent session")
|
||||
}
|
||||
if err.Error() != tc.expectedError {
|
||||
t.Errorf("Expected error '%s', got: %v", tc.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinueSession_ValidatesClaudeSessionID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
manager, _ := NewManager(nil, mockStore)
|
||||
|
||||
// Parent without claude_session_id
|
||||
parentSession := &store.Session{
|
||||
ID: "parent-1",
|
||||
RunID: "run-1",
|
||||
ClaudeSessionID: "", // Empty
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "parent-1",
|
||||
Query: "continue this",
|
||||
}
|
||||
_, err := manager.ContinueSession(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Error("Expected error for parent without claude_session_id")
|
||||
}
|
||||
if err.Error() != "parent session missing claude_session_id (cannot resume)" {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinueSession_CreatesNewSessionWithParentReference(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
manager, _ := NewManager(nil, mockStore)
|
||||
|
||||
// Mock parent session
|
||||
parentSession := &store.Session{
|
||||
ID: "parent-1",
|
||||
RunID: "run-1",
|
||||
ClaudeSessionID: "claude-1",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
Model: "claude-3-opus",
|
||||
WorkingDir: "/test/dir",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
// Expect session creation with parent reference
|
||||
mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx interface{}, session *store.Session) error {
|
||||
// Validate the created session
|
||||
if session.ParentSessionID != "parent-1" {
|
||||
t.Errorf("Expected parent_session_id to be 'parent-1', got '%s'", session.ParentSessionID)
|
||||
}
|
||||
if session.Query != "continue this" {
|
||||
t.Errorf("Expected query 'continue this', got '%s'", session.Query)
|
||||
}
|
||||
if session.Status != store.SessionStatusStarting {
|
||||
t.Errorf("Expected status 'starting', got '%s'", session.Status)
|
||||
}
|
||||
// Should not inherit claude_session_id (will be set from streaming events)
|
||||
if session.ClaudeSessionID != "" {
|
||||
t.Errorf("Expected empty claude_session_id, got '%s'", session.ClaudeSessionID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Expect status update to running (we can't test the full flow without mocking Claude client)
|
||||
// May be called twice if Claude fails to launch in background
|
||||
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "parent-1",
|
||||
Query: "continue this",
|
||||
}
|
||||
|
||||
// Try to continue session - this tests our logic, not Claude launch
|
||||
session, err := manager.ContinueSession(context.Background(), req)
|
||||
|
||||
// If Claude binary exists, it might succeed; if not, it will fail
|
||||
// Either way, our mock expectations should have been met (session created with parent)
|
||||
if err != nil {
|
||||
// Expected - Claude binary might not exist in test environment
|
||||
if !containsError(err, "failed to launch resumed Claude session") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Claude launched successfully - verify session has expected properties
|
||||
if session.ID == "" {
|
||||
t.Error("Expected session ID to be set")
|
||||
}
|
||||
if session.RunID == "" {
|
||||
t.Error("Expected run ID to be set")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinueSession_HandlesOptionalOverrides(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockConversationStore(ctrl)
|
||||
manager, _ := NewManager(nil, mockStore)
|
||||
|
||||
// Mock parent session
|
||||
parentSession := &store.Session{
|
||||
ID: "parent-1",
|
||||
RunID: "run-1",
|
||||
ClaudeSessionID: "claude-1",
|
||||
Status: store.SessionStatusCompleted,
|
||||
Query: "original query",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockStore.EXPECT().GetSession(gomock.Any(), "parent-1").Return(parentSession, nil)
|
||||
|
||||
// Test with various overrides
|
||||
req := ContinueSessionConfig{
|
||||
ParentSessionID: "parent-1",
|
||||
Query: "continue with overrides",
|
||||
SystemPrompt: "You are a pirate",
|
||||
AppendSystemPrompt: "Always say arr",
|
||||
PermissionPromptTool: "mcp__custom__tool",
|
||||
AllowedTools: []string{"tool1", "tool2"},
|
||||
DisallowedTools: []string{"dangerous_tool"},
|
||||
CustomInstructions: "Be helpful",
|
||||
MaxTurns: 5,
|
||||
}
|
||||
|
||||
// Expect session creation
|
||||
mockStore.EXPECT().CreateSession(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx interface{}, session *store.Session) error {
|
||||
// Validate overrides are stored
|
||||
if session.SystemPrompt != "You are a pirate" {
|
||||
t.Errorf("Expected system prompt override, got '%s'", session.SystemPrompt)
|
||||
}
|
||||
if session.CustomInstructions != "Be helpful" {
|
||||
t.Errorf("Expected custom instructions override, got '%s'", session.CustomInstructions)
|
||||
}
|
||||
if session.MaxTurns != 5 {
|
||||
t.Errorf("Expected max turns 5, got %d", session.MaxTurns)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Expect status update
|
||||
// May be called twice if Claude fails to launch in background
|
||||
mockStore.EXPECT().UpdateSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
session, err := manager.ContinueSession(context.Background(), req)
|
||||
|
||||
// Test passes if our mock expectations were met (session created with overrides)
|
||||
// Whether Claude actually launches depends on the environment
|
||||
if err != nil {
|
||||
// Expected - Claude binary might not exist
|
||||
if !containsError(err, "failed to launch resumed Claude session") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Claude launched - verify session properties
|
||||
if session.ID == "" {
|
||||
t.Error("Expected session ID to be set")
|
||||
}
|
||||
if session.RunID == "" {
|
||||
t.Error("Expected run ID to be set")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if error contains a substring
|
||||
func containsError(err error, substr string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return contains(err.Error(), substr)
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsStr(s, substr))
|
||||
}
|
||||
|
||||
func containsStr(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,11 +42,28 @@ type Info struct {
|
||||
Result *claudecode.Result `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// ContinueSessionConfig contains the configuration for continuing a session
|
||||
type ContinueSessionConfig struct {
|
||||
ParentSessionID string // The parent session to resume from
|
||||
Query string // The new query
|
||||
SystemPrompt string // Optional system prompt override
|
||||
AppendSystemPrompt string // Optional append to system prompt
|
||||
MCPConfig *claudecode.MCPConfig // Optional MCP config override
|
||||
PermissionPromptTool string // Optional permission prompt tool
|
||||
AllowedTools []string // Optional allowed tools override
|
||||
DisallowedTools []string // Optional disallowed tools override
|
||||
CustomInstructions string // Optional custom instructions
|
||||
MaxTurns int // Optional max turns override
|
||||
}
|
||||
|
||||
// SessionManager defines the interface for managing Claude Code sessions
|
||||
type SessionManager interface {
|
||||
// LaunchSession starts a new Claude Code session
|
||||
LaunchSession(ctx context.Context, config claudecode.SessionConfig) (*Session, error)
|
||||
|
||||
// ContinueSession resumes an existing completed session with a new query and optional config overrides
|
||||
ContinueSession(ctx context.Context, req ContinueSessionConfig) (*Session, error)
|
||||
|
||||
// GetSessionInfo returns session info from the database by ID
|
||||
GetSessionInfo(sessionID string) (*Info, error)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
claudecode "github.com/humanlayer/humanlayer/claudecode-go"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@@ -523,24 +524,98 @@ func (s *SQLiteStore) GetConversation(ctx context.Context, claudeSessionID strin
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetSessionConversation retrieves all events for a session
|
||||
// GetSessionConversation retrieves all events for a session including parent history
|
||||
func (s *SQLiteStore) GetSessionConversation(ctx context.Context, sessionID string) ([]*ConversationEvent, error) {
|
||||
// First get the claude_session_id for this session
|
||||
// Walk up the parent chain to get all related claude session IDs
|
||||
claudeSessionIDs := []string{}
|
||||
currentID := sessionID
|
||||
|
||||
for currentID != "" {
|
||||
var claudeSessionID sql.NullString
|
||||
var parentID sql.NullString
|
||||
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
"SELECT claude_session_id FROM sessions WHERE id = ?",
|
||||
sessionID,
|
||||
).Scan(&claudeSessionID)
|
||||
"SELECT claude_session_id, parent_session_id FROM sessions WHERE id = ?",
|
||||
currentID,
|
||||
).Scan(&claudeSessionID, &parentID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
break // Session not found, stop walking
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
if !claudeSessionID.Valid {
|
||||
// No claude session yet, return empty
|
||||
// Add claude session ID if present (in reverse order for chronological events)
|
||||
if claudeSessionID.Valid && claudeSessionID.String != "" {
|
||||
claudeSessionIDs = append([]string{claudeSessionID.String}, claudeSessionIDs...)
|
||||
}
|
||||
|
||||
// Move to parent
|
||||
if parentID.Valid {
|
||||
currentID = parentID.String
|
||||
} else {
|
||||
currentID = ""
|
||||
}
|
||||
}
|
||||
|
||||
if len(claudeSessionIDs) == 0 {
|
||||
// No claude sessions yet, return empty
|
||||
return []*ConversationEvent{}, nil
|
||||
}
|
||||
|
||||
return s.GetConversation(ctx, claudeSessionID.String)
|
||||
// Get all events for all claude session IDs in chronological order
|
||||
placeholders := make([]string, len(claudeSessionIDs))
|
||||
args := make([]interface{}, len(claudeSessionIDs))
|
||||
for i, id := range claudeSessionIDs {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
// Build query that orders by the position in the claude session ID list first
|
||||
// This ensures parent events come before child events
|
||||
orderCases := make([]string, len(claudeSessionIDs))
|
||||
for i := range claudeSessionIDs {
|
||||
orderCases[i] = fmt.Sprintf("WHEN claude_session_id = ? THEN %d", i)
|
||||
args = append(args, claudeSessionIDs[i])
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, session_id, claude_session_id, sequence, event_type, created_at,
|
||||
role, content,
|
||||
tool_id, tool_name, tool_input_json,
|
||||
tool_result_for_id, tool_result_content,
|
||||
is_completed, approval_status, approval_id
|
||||
FROM conversation_events
|
||||
WHERE claude_session_id IN (%s)
|
||||
ORDER BY
|
||||
CASE %s END,
|
||||
sequence
|
||||
`, strings.Join(placeholders, ","), strings.Join(orderCases, " "))
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var events []*ConversationEvent
|
||||
for rows.Next() {
|
||||
event := &ConversationEvent{}
|
||||
err := rows.Scan(
|
||||
&event.ID, &event.SessionID, &event.ClaudeSessionID,
|
||||
&event.Sequence, &event.EventType, &event.CreatedAt,
|
||||
&event.Role, &event.Content,
|
||||
&event.ToolID, &event.ToolName, &event.ToolInputJSON,
|
||||
&event.ToolResultForID, &event.ToolResultContent,
|
||||
&event.IsCompleted, &event.ApprovalStatus, &event.ApprovalID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// GetPendingToolCall finds the most recent uncompleted tool call for a given session and tool name
|
||||
|
||||
@@ -262,3 +262,203 @@ func TestSQLiteStore(t *testing.T) {
|
||||
require.Len(t, sessions, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetSessionConversationWithParentChain(t *testing.T) {
|
||||
// Create temp database
|
||||
tmpDir, err := os.MkdirTemp("", "hld-test-parent-*")
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = os.RemoveAll(tmpDir) }()
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := NewSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create parent session
|
||||
parentSession := &Session{
|
||||
ID: "parent-1",
|
||||
RunID: "run-1",
|
||||
ClaudeSessionID: "claude-parent",
|
||||
Query: "Tell me about Go",
|
||||
Status: SessionStatusCompleted,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
err = store.CreateSession(ctx, parentSession)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add events to parent session
|
||||
parentEvents := []*ConversationEvent{
|
||||
{
|
||||
SessionID: "parent-1",
|
||||
ClaudeSessionID: "claude-parent",
|
||||
Sequence: 1,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "Tell me about Go",
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
SessionID: "parent-1",
|
||||
ClaudeSessionID: "claude-parent",
|
||||
Sequence: 2,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "Go is a statically typed programming language...",
|
||||
CreatedAt: time.Now().Add(1 * time.Second),
|
||||
},
|
||||
}
|
||||
for _, event := range parentEvents {
|
||||
err = store.AddConversationEvent(ctx, event)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create child session
|
||||
childSession := &Session{
|
||||
ID: "child-1",
|
||||
RunID: "run-2",
|
||||
ClaudeSessionID: "claude-child",
|
||||
ParentSessionID: "parent-1",
|
||||
Query: "Tell me more about goroutines",
|
||||
Status: SessionStatusCompleted,
|
||||
CreatedAt: time.Now().Add(5 * time.Second),
|
||||
LastActivityAt: time.Now().Add(5 * time.Second),
|
||||
}
|
||||
err = store.CreateSession(ctx, childSession)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add events to child session
|
||||
childEvents := []*ConversationEvent{
|
||||
{
|
||||
SessionID: "child-1",
|
||||
ClaudeSessionID: "claude-child",
|
||||
Sequence: 1,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "Tell me more about goroutines",
|
||||
CreatedAt: time.Now().Add(10 * time.Second),
|
||||
},
|
||||
{
|
||||
SessionID: "child-1",
|
||||
ClaudeSessionID: "claude-child",
|
||||
Sequence: 2,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "Goroutines are lightweight threads...",
|
||||
CreatedAt: time.Now().Add(11 * time.Second),
|
||||
},
|
||||
}
|
||||
for _, event := range childEvents {
|
||||
err = store.AddConversationEvent(ctx, event)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create grandchild session
|
||||
grandchildSession := &Session{
|
||||
ID: "grandchild-1",
|
||||
RunID: "run-3",
|
||||
ClaudeSessionID: "claude-grandchild",
|
||||
ParentSessionID: "child-1",
|
||||
Query: "How do channels work?",
|
||||
Status: SessionStatusRunning,
|
||||
CreatedAt: time.Now().Add(20 * time.Second),
|
||||
LastActivityAt: time.Now().Add(20 * time.Second),
|
||||
}
|
||||
err = store.CreateSession(ctx, grandchildSession)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add events to grandchild session
|
||||
grandchildEvents := []*ConversationEvent{
|
||||
{
|
||||
SessionID: "grandchild-1",
|
||||
ClaudeSessionID: "claude-grandchild",
|
||||
Sequence: 1,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "user",
|
||||
Content: "How do channels work?",
|
||||
CreatedAt: time.Now().Add(25 * time.Second),
|
||||
},
|
||||
{
|
||||
SessionID: "grandchild-1",
|
||||
ClaudeSessionID: "claude-grandchild",
|
||||
Sequence: 2,
|
||||
EventType: EventTypeMessage,
|
||||
Role: "assistant",
|
||||
Content: "Channels are Go's way of communication...",
|
||||
CreatedAt: time.Now().Add(26 * time.Second),
|
||||
},
|
||||
}
|
||||
for _, event := range grandchildEvents {
|
||||
err = store.AddConversationEvent(ctx, event)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetSessionConversation_IncludesFullHistory", func(t *testing.T) {
|
||||
// Get conversation for grandchild - should include all parent events
|
||||
events, err := store.GetSessionConversation(ctx, "grandchild-1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 6) // 2 from parent + 2 from child + 2 from grandchild
|
||||
|
||||
// Verify chronological order
|
||||
expectedContents := []string{
|
||||
"Tell me about Go",
|
||||
"Go is a statically typed programming language...",
|
||||
"Tell me more about goroutines",
|
||||
"Goroutines are lightweight threads...",
|
||||
"How do channels work?",
|
||||
"Channels are Go's way of communication...",
|
||||
}
|
||||
|
||||
for i, event := range events {
|
||||
require.Equal(t, expectedContents[i], event.Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_SessionWithoutClaudeID", func(t *testing.T) {
|
||||
// Create session without claude_session_id yet
|
||||
newSession := &Session{
|
||||
ID: "new-session",
|
||||
RunID: "run-4",
|
||||
Query: "New query",
|
||||
Status: SessionStatusStarting,
|
||||
CreatedAt: time.Now(),
|
||||
LastActivityAt: time.Now(),
|
||||
}
|
||||
err = store.CreateSession(ctx, newSession)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return empty events
|
||||
events, err := store.GetSessionConversation(ctx, "new-session")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 0)
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_NonExistentSession", func(t *testing.T) {
|
||||
// Should return empty events for non-existent session
|
||||
events, err := store.GetSessionConversation(ctx, "does-not-exist")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 0)
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_NoParent", func(t *testing.T) {
|
||||
// Get conversation for parent session (no parents)
|
||||
events, err := store.GetSessionConversation(ctx, "parent-1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 2) // Only parent's events
|
||||
})
|
||||
|
||||
t.Run("GetSessionConversation_MiddleOfChain", func(t *testing.T) {
|
||||
// Get conversation for child (middle of chain)
|
||||
events, err := store.GetSessionConversation(ctx, "child-1")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 4) // Parent's 2 + child's 2
|
||||
|
||||
// Verify we get parent events first, then child
|
||||
require.Equal(t, "Tell me about Go", events[0].Content)
|
||||
require.Equal(t, "Go is a statically typed programming language...", events[1].Content)
|
||||
require.Equal(t, "Tell me more about goroutines", events[2].Content)
|
||||
require.Equal(t, "Goroutines are lightweight threads...", events[3].Content)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user