mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Add RFC and implementation of import and stats
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +15,7 @@ import (
|
||||
"github.com/charmbracelet/crush/internal/db"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/session"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -22,6 +26,89 @@ type SessionWithChildren struct {
|
||||
Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"`
|
||||
}
|
||||
|
||||
// ImportSession represents a session with proper JSON tags for import
|
||||
type ImportSession struct {
|
||||
ID string `json:"id"`
|
||||
ParentSessionID string `json:"parent_session_id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
SummaryMessageID string `json:"summary_message_id,omitempty"`
|
||||
Children []ImportSession `json:"children,omitempty"`
|
||||
}
|
||||
|
||||
// ImportData represents the full import structure for sessions
|
||||
type ImportData struct {
|
||||
Version string `json:"version" yaml:"version"`
|
||||
ExportedAt string `json:"exported_at,omitempty" yaml:"exported_at,omitempty"`
|
||||
TotalSessions int `json:"total_sessions,omitempty" yaml:"total_sessions,omitempty"`
|
||||
Sessions []ImportSession `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
// ImportMessage represents a message with proper JSON tags for import
|
||||
type ImportMessage struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
SessionID string `json:"session_id"`
|
||||
Parts []interface{} `json:"parts"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// ImportSessionInfo represents session info with proper JSON tags for conversation import
|
||||
type ImportSessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens,omitempty"`
|
||||
CompletionTokens int64 `json:"completion_tokens,omitempty"`
|
||||
Cost float64 `json:"cost,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// ConversationData represents a single conversation import structure
|
||||
type ConversationData struct {
|
||||
Version string `json:"version" yaml:"version"`
|
||||
Session ImportSessionInfo `json:"session" yaml:"session"`
|
||||
Messages []ImportMessage `json:"messages" yaml:"messages"`
|
||||
}
|
||||
|
||||
// ImportResult contains the results of an import operation
|
||||
type ImportResult struct {
|
||||
TotalSessions int `json:"total_sessions"`
|
||||
ImportedSessions int `json:"imported_sessions"`
|
||||
SkippedSessions int `json:"skipped_sessions"`
|
||||
ImportedMessages int `json:"imported_messages"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
|
||||
}
|
||||
|
||||
// SessionStats represents aggregated session statistics
|
||||
type SessionStats struct {
|
||||
TotalSessions int64 `json:"total_sessions"`
|
||||
TotalMessages int64 `json:"total_messages"`
|
||||
TotalPromptTokens int64 `json:"total_prompt_tokens"`
|
||||
TotalCompletionTokens int64 `json:"total_completion_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
AvgCostPerSession float64 `json:"avg_cost_per_session"`
|
||||
}
|
||||
|
||||
// GroupedSessionStats represents statistics grouped by time period
|
||||
type GroupedSessionStats struct {
|
||||
Period string `json:"period"`
|
||||
SessionCount int64 `json:"session_count"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
AvgCost float64 `json:"avg_cost"`
|
||||
}
|
||||
|
||||
var sessionsCmd = &cobra.Command{
|
||||
Use: "sessions",
|
||||
Short: "Manage sessions",
|
||||
@@ -60,15 +147,80 @@ var exportConversationCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
var importCmd = &cobra.Command{
|
||||
Use: "import <file>",
|
||||
Short: "Import sessions from a file",
|
||||
Long: `Import sessions from a JSON or YAML file with hierarchical structure`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
file := args[0]
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
dryRun, _ := cmd.Flags().GetBool("dry-run")
|
||||
return runImport(cmd.Context(), file, format, dryRun)
|
||||
},
|
||||
}
|
||||
|
||||
var importConversationCmd = &cobra.Command{
|
||||
Use: "import-conversation <file>",
|
||||
Short: "Import a single conversation from a file",
|
||||
Long: `Import a single conversation with messages from a JSON, YAML, or Markdown file`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
file := args[0]
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
return runImportConversation(cmd.Context(), file, format)
|
||||
},
|
||||
}
|
||||
|
||||
var searchCmd = &cobra.Command{
|
||||
Use: "search",
|
||||
Short: "Search sessions by title or message content",
|
||||
Long: `Search sessions by title pattern (case-insensitive) or message text content`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
titlePattern, _ := cmd.Flags().GetString("title")
|
||||
textPattern, _ := cmd.Flags().GetString("text")
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
|
||||
if titlePattern == "" && textPattern == "" {
|
||||
return fmt.Errorf("at least one of --title or --text must be provided")
|
||||
}
|
||||
|
||||
return runSessionsSearch(cmd.Context(), titlePattern, textPattern, format)
|
||||
},
|
||||
}
|
||||
|
||||
var statsCmd = &cobra.Command{
|
||||
Use: "stats",
|
||||
Short: "Show session statistics",
|
||||
Long: `Display aggregated statistics about sessions including total counts, tokens, and costs`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
format, _ := cmd.Flags().GetString("format")
|
||||
groupBy, _ := cmd.Flags().GetString("group-by")
|
||||
return runSessionsStats(cmd.Context(), format, groupBy)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(sessionsCmd)
|
||||
sessionsCmd.AddCommand(listCmd)
|
||||
sessionsCmd.AddCommand(exportCmd)
|
||||
sessionsCmd.AddCommand(exportConversationCmd)
|
||||
sessionsCmd.AddCommand(importCmd)
|
||||
sessionsCmd.AddCommand(importConversationCmd)
|
||||
sessionsCmd.AddCommand(searchCmd)
|
||||
sessionsCmd.AddCommand(statsCmd)
|
||||
|
||||
listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)")
|
||||
exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)")
|
||||
exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)")
|
||||
importCmd.Flags().StringP("format", "f", "", "Import format (json, yaml) - auto-detected if not specified")
|
||||
importCmd.Flags().Bool("dry-run", false, "Validate import data without persisting changes")
|
||||
importConversationCmd.Flags().StringP("format", "f", "", "Import format (json, yaml, markdown) - auto-detected if not specified")
|
||||
searchCmd.Flags().String("title", "", "Search by session title pattern (case-insensitive substring search)")
|
||||
searchCmd.Flags().String("text", "", "Search by message text content")
|
||||
searchCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
|
||||
statsCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
|
||||
statsCmd.Flags().String("group-by", "", "Group statistics by time period (day, week, month)")
|
||||
}
|
||||
|
||||
func runSessionsList(ctx context.Context, format string) error {
|
||||
@@ -442,3 +594,773 @@ func formatConversationYAML(sess session.Session, messages []message.Message) er
|
||||
fmt.Println(string(yamlData))
|
||||
return nil
|
||||
}
|
||||
|
||||
func runImport(ctx context.Context, file, format string, dryRun bool) error {
|
||||
// Read the file
|
||||
data, err := readImportFile(file, format)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read import file: %w", err)
|
||||
}
|
||||
|
||||
// Validate the data structure
|
||||
if err := validateImportData(data); err != nil {
|
||||
return fmt.Errorf("invalid import data: %w", err)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
result := ImportResult{
|
||||
TotalSessions: countTotalImportSessions(data.Sessions),
|
||||
ImportedSessions: 0,
|
||||
SkippedSessions: 0,
|
||||
ImportedMessages: 0,
|
||||
SessionMapping: make(map[string]string),
|
||||
}
|
||||
fmt.Printf("Dry run: Would import %d sessions\n", result.TotalSessions)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Perform the actual import
|
||||
sessionService, messageService, err := createServices(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := importSessions(ctx, sessionService, messageService, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("import failed: %w", err)
|
||||
}
|
||||
|
||||
// Print summary
|
||||
fmt.Printf("Import completed successfully:\n")
|
||||
fmt.Printf(" Total sessions processed: %d\n", result.TotalSessions)
|
||||
fmt.Printf(" Sessions imported: %d\n", result.ImportedSessions)
|
||||
fmt.Printf(" Sessions skipped: %d\n", result.SkippedSessions)
|
||||
fmt.Printf(" Messages imported: %d\n", result.ImportedMessages)
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
fmt.Printf(" Errors encountered: %d\n", len(result.Errors))
|
||||
for _, errStr := range result.Errors {
|
||||
fmt.Printf(" - %s\n", errStr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runImportConversation(ctx context.Context, file, format string) error {
|
||||
// Read the conversation file
|
||||
convData, err := readConversationFile(file, format)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read conversation file: %w", err)
|
||||
}
|
||||
|
||||
// Validate the conversation data
|
||||
if err := validateConversationData(convData); err != nil {
|
||||
return fmt.Errorf("invalid conversation data: %w", err)
|
||||
}
|
||||
|
||||
// Import the conversation
|
||||
sessionService, messageService, err := createServices(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newSessionID, messageCount, err := importConversation(ctx, sessionService, messageService, convData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("conversation import failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Conversation imported successfully:\n")
|
||||
fmt.Printf(" Session ID: %s\n", newSessionID)
|
||||
fmt.Printf(" Title: %s\n", convData.Session.Title)
|
||||
fmt.Printf(" Messages imported: %d\n", messageCount)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readImportFile(file, format string) (*ImportData, error) {
|
||||
fileData, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %s: %w", file, err)
|
||||
}
|
||||
|
||||
// Auto-detect format if not specified
|
||||
if format == "" {
|
||||
format = detectFormat(file, fileData)
|
||||
}
|
||||
|
||||
var data ImportData
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
if err := json.Unmarshal(fileData, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
case "yaml", "yml":
|
||||
if err := yaml.Unmarshal(fileData, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func readConversationFile(file, format string) (*ConversationData, error) {
|
||||
fileData, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %s: %w", file, err)
|
||||
}
|
||||
|
||||
// Auto-detect format if not specified
|
||||
if format == "" {
|
||||
format = detectFormat(file, fileData)
|
||||
}
|
||||
|
||||
var data ConversationData
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
if err := json.Unmarshal(fileData, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
case "yaml", "yml":
|
||||
if err := yaml.Unmarshal(fileData, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %w", err)
|
||||
}
|
||||
case "markdown", "md":
|
||||
return nil, fmt.Errorf("markdown import for conversations is not yet implemented")
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func detectFormat(filename string, data []byte) string {
|
||||
// First try file extension
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
switch ext {
|
||||
case ".json":
|
||||
return "json"
|
||||
case ".yaml", ".yml":
|
||||
return "yaml"
|
||||
case ".md", ".markdown":
|
||||
return "markdown"
|
||||
}
|
||||
|
||||
// Try to detect from content
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) > 0 {
|
||||
if data[0] == '{' || data[0] == '[' {
|
||||
return "json"
|
||||
}
|
||||
if strings.HasPrefix(string(data), "---") || strings.Contains(string(data), ":") {
|
||||
return "yaml"
|
||||
}
|
||||
}
|
||||
|
||||
return "json" // default fallback
|
||||
}
|
||||
|
||||
func validateImportData(data *ImportData) error {
|
||||
if data == nil {
|
||||
return fmt.Errorf("import data is nil")
|
||||
}
|
||||
|
||||
if len(data.Sessions) == 0 {
|
||||
return fmt.Errorf("no sessions to import")
|
||||
}
|
||||
|
||||
// Validate session structure
|
||||
for i, sess := range data.Sessions {
|
||||
if err := validateImportSessionHierarchy(sess, ""); err != nil {
|
||||
return fmt.Errorf("session %d validation failed: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConversationData(data *ConversationData) error {
|
||||
if data == nil {
|
||||
return fmt.Errorf("conversation data is nil")
|
||||
}
|
||||
|
||||
if data.Session.Title == "" {
|
||||
return fmt.Errorf("session title is required")
|
||||
}
|
||||
|
||||
if len(data.Messages) == 0 {
|
||||
return fmt.Errorf("no messages to import")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateImportSessionHierarchy(sess ImportSession, expectedParent string) error {
|
||||
if sess.ID == "" {
|
||||
return fmt.Errorf("session ID is required")
|
||||
}
|
||||
|
||||
if sess.Title == "" {
|
||||
return fmt.Errorf("session title is required")
|
||||
}
|
||||
|
||||
// For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
|
||||
if expectedParent == "" {
|
||||
if sess.ParentSessionID != "" {
|
||||
return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
|
||||
}
|
||||
} else {
|
||||
// For child sessions, parent should match expected parent
|
||||
if sess.ParentSessionID != expectedParent {
|
||||
return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate children
|
||||
for _, child := range sess.Children {
|
||||
if err := validateImportSessionHierarchy(child, sess.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSessionHierarchy(sess SessionWithChildren, expectedParent string) error {
|
||||
if sess.ID == "" {
|
||||
return fmt.Errorf("session ID is required")
|
||||
}
|
||||
|
||||
if sess.Title == "" {
|
||||
return fmt.Errorf("session title is required")
|
||||
}
|
||||
|
||||
// For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
|
||||
if expectedParent == "" {
|
||||
if sess.ParentSessionID != "" {
|
||||
return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
|
||||
}
|
||||
} else {
|
||||
// For child sessions, parent should match expected parent
|
||||
if sess.ParentSessionID != expectedParent {
|
||||
return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate children
|
||||
for _, child := range sess.Children {
|
||||
if err := validateSessionHierarchy(child, sess.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func countTotalImportSessions(sessions []ImportSession) int {
|
||||
count := len(sessions)
|
||||
for _, sess := range sessions {
|
||||
count += countTotalImportSessions(sess.Children)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countTotalSessions(sessions []SessionWithChildren) int {
|
||||
count := len(sessions)
|
||||
for _, sess := range sessions {
|
||||
count += countTotalSessions(sess.Children)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func importSessions(ctx context.Context, sessionService session.Service, messageService message.Service, data *ImportData) (ImportResult, error) {
|
||||
result := ImportResult{
|
||||
TotalSessions: countTotalImportSessions(data.Sessions),
|
||||
SessionMapping: make(map[string]string),
|
||||
}
|
||||
|
||||
// Import sessions recursively, starting with top-level sessions
|
||||
for _, sess := range data.Sessions {
|
||||
err := importImportSessionWithChildren(ctx, sessionService, messageService, sess, "", &result)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to import session %s: %v", sess.ID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func importConversation(ctx context.Context, sessionService session.Service, messageService message.Service, data *ConversationData) (string, int, error) {
|
||||
// Generate new session ID
|
||||
newSessionID := uuid.New().String()
|
||||
|
||||
// Create the session using the low-level database API
|
||||
cwd, err := getCwd()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
cfg, err := config.Init(cwd, false)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
queries := db.New(conn)
|
||||
|
||||
// Create session with all original metadata
|
||||
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: newSessionID,
|
||||
ParentSessionID: sql.NullString{Valid: false},
|
||||
Title: data.Session.Title,
|
||||
MessageCount: data.Session.MessageCount,
|
||||
PromptTokens: data.Session.PromptTokens,
|
||||
CompletionTokens: data.Session.CompletionTokens,
|
||||
Cost: data.Session.Cost,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
// Import messages
|
||||
messageCount := 0
|
||||
for _, msg := range data.Messages {
|
||||
// Generate new message ID
|
||||
newMessageID := uuid.New().String()
|
||||
|
||||
// Marshal message parts
|
||||
partsJSON, err := json.Marshal(msg.Parts)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("failed to marshal message parts: %w", err)
|
||||
}
|
||||
|
||||
// Create message
|
||||
_, err = queries.CreateMessage(ctx, db.CreateMessageParams{
|
||||
ID: newMessageID,
|
||||
SessionID: newSessionID,
|
||||
Role: string(msg.Role),
|
||||
Parts: string(partsJSON),
|
||||
Model: sql.NullString{String: msg.Model, Valid: msg.Model != ""},
|
||||
Provider: sql.NullString{String: msg.Provider, Valid: msg.Provider != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("failed to create message: %w", err)
|
||||
}
|
||||
messageCount++
|
||||
}
|
||||
|
||||
return newSessionID, messageCount, nil
|
||||
}
|
||||
|
||||
func importImportSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess ImportSession, parentID string, result *ImportResult) error {
|
||||
// Generate new session ID
|
||||
newSessionID := uuid.New().String()
|
||||
result.SessionMapping[sess.ID] = newSessionID
|
||||
|
||||
// Create the session using the low-level database API to preserve metadata
|
||||
cwd, err := getCwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := config.Init(cwd, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := db.New(conn)
|
||||
|
||||
// Create session with all original metadata
|
||||
parentSessionID := sql.NullString{Valid: false}
|
||||
if parentID != "" {
|
||||
parentSessionID = sql.NullString{String: parentID, Valid: true}
|
||||
}
|
||||
|
||||
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: newSessionID,
|
||||
ParentSessionID: parentSessionID,
|
||||
Title: sess.Title,
|
||||
MessageCount: sess.MessageCount,
|
||||
PromptTokens: sess.PromptTokens,
|
||||
CompletionTokens: sess.CompletionTokens,
|
||||
Cost: sess.Cost,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
result.ImportedSessions++
|
||||
|
||||
// Import children recursively
|
||||
for _, child := range sess.Children {
|
||||
err := importImportSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func importSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess SessionWithChildren, parentID string, result *ImportResult) error {
|
||||
// Generate new session ID
|
||||
newSessionID := uuid.New().String()
|
||||
result.SessionMapping[sess.ID] = newSessionID
|
||||
|
||||
// Create the session using the low-level database API to preserve metadata
|
||||
cwd, err := getCwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := config.Init(cwd, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := db.New(conn)
|
||||
|
||||
// Create session with all original metadata
|
||||
parentSessionID := sql.NullString{Valid: false}
|
||||
if parentID != "" {
|
||||
parentSessionID = sql.NullString{String: parentID, Valid: true}
|
||||
}
|
||||
|
||||
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
|
||||
ID: newSessionID,
|
||||
ParentSessionID: parentSessionID,
|
||||
Title: sess.Title,
|
||||
MessageCount: sess.MessageCount,
|
||||
PromptTokens: sess.PromptTokens,
|
||||
CompletionTokens: sess.CompletionTokens,
|
||||
Cost: sess.Cost,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
result.ImportedSessions++
|
||||
|
||||
// Import children recursively
|
||||
for _, child := range sess.Children {
|
||||
err := importSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
|
||||
if err != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSessionsSearch(ctx context.Context, titlePattern, textPattern, format string) error {
|
||||
sessionService, err := createSessionService(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sessions []session.Session
|
||||
|
||||
// Determine which search method to use based on provided patterns
|
||||
if titlePattern != "" && textPattern != "" {
|
||||
sessions, err = sessionService.SearchByTitleAndText(ctx, titlePattern, textPattern)
|
||||
} else if titlePattern != "" {
|
||||
sessions, err = sessionService.SearchByTitle(ctx, titlePattern)
|
||||
} else if textPattern != "" {
|
||||
sessions, err = sessionService.SearchByText(ctx, textPattern)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("search failed: %w", err)
|
||||
}
|
||||
|
||||
return formatSearchResults(sessions, format)
|
||||
}
|
||||
|
||||
func formatSearchResults(sessions []session.Session, format string) error {
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
return formatSearchResultsJSON(sessions)
|
||||
case "text":
|
||||
return formatSearchResultsText(sessions)
|
||||
default:
|
||||
return fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
func formatSearchResultsJSON(sessions []session.Session) error {
|
||||
data, err := json.MarshalIndent(sessions, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatSearchResultsText(sessions []session.Session) error {
|
||||
if len(sessions) == 0 {
|
||||
fmt.Println("No sessions found matching the search criteria.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Found %d session(s):\n\n", len(sessions))
|
||||
for _, sess := range sessions {
|
||||
fmt.Printf("• %s (ID: %s)\n", sess.Title, sess.ID)
|
||||
fmt.Printf(" Messages: %d, Cost: $%.4f\n", sess.MessageCount, sess.Cost)
|
||||
fmt.Printf(" Created: %s\n", formatTimestamp(sess.CreatedAt))
|
||||
if sess.ParentSessionID != "" {
|
||||
fmt.Printf(" Parent: %s\n", sess.ParentSessionID)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSessionsStats(ctx context.Context, format, groupBy string) error {
|
||||
// Get database connection
|
||||
cwd, err := getCwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := config.Init(cwd, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := db.New(conn)
|
||||
|
||||
// Handle grouped statistics
|
||||
if groupBy != "" {
|
||||
return runGroupedStats(ctx, queries, format, groupBy)
|
||||
}
|
||||
|
||||
// Get overall statistics
|
||||
statsRow, err := queries.GetSessionStats(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session stats: %w", err)
|
||||
}
|
||||
|
||||
// Convert to our struct, handling NULL values
|
||||
stats := SessionStats{
|
||||
TotalSessions: statsRow.TotalSessions,
|
||||
TotalMessages: convertNullFloat64ToInt64(statsRow.TotalMessages),
|
||||
TotalPromptTokens: convertNullFloat64ToInt64(statsRow.TotalPromptTokens),
|
||||
TotalCompletionTokens: convertNullFloat64ToInt64(statsRow.TotalCompletionTokens),
|
||||
TotalCost: convertNullFloat64(statsRow.TotalCost),
|
||||
AvgCostPerSession: convertNullFloat64(statsRow.AvgCostPerSession),
|
||||
}
|
||||
|
||||
return formatStats(stats, format)
|
||||
}
|
||||
|
||||
func runGroupedStats(ctx context.Context, queries *db.Queries, format, groupBy string) error {
|
||||
var groupedStats []GroupedSessionStats
|
||||
|
||||
switch strings.ToLower(groupBy) {
|
||||
case "day":
|
||||
rows, err := queries.GetSessionStatsByDay(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get daily stats: %w", err)
|
||||
}
|
||||
groupedStats = convertDayStatsRows(rows)
|
||||
case "week":
|
||||
rows, err := queries.GetSessionStatsByWeek(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get weekly stats: %w", err)
|
||||
}
|
||||
groupedStats = convertWeekStatsRows(rows)
|
||||
case "month":
|
||||
rows, err := queries.GetSessionStatsByMonth(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get monthly stats: %w", err)
|
||||
}
|
||||
groupedStats = convertMonthStatsRows(rows)
|
||||
default:
|
||||
return fmt.Errorf("unsupported group-by value: %s. Valid values are: day, week, month", groupBy)
|
||||
}
|
||||
|
||||
return formatGroupedStats(groupedStats, format, groupBy)
|
||||
}
|
||||
|
||||
func convertNullFloat64(val sql.NullFloat64) float64 {
|
||||
if val.Valid {
|
||||
return val.Float64
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
func convertNullFloat64ToInt64(val sql.NullFloat64) int64 {
|
||||
if val.Valid {
|
||||
return int64(val.Float64)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func convertDayStatsRows(rows []db.GetSessionStatsByDayRow) []GroupedSessionStats {
|
||||
result := make([]GroupedSessionStats, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
stats := GroupedSessionStats{
|
||||
Period: fmt.Sprintf("%v", row.Day),
|
||||
SessionCount: row.SessionCount,
|
||||
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
|
||||
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
|
||||
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
|
||||
TotalCost: convertNullFloat64(row.TotalCost),
|
||||
AvgCost: convertNullFloat64(row.AvgCost),
|
||||
}
|
||||
result = append(result, stats)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertWeekStatsRows(rows []db.GetSessionStatsByWeekRow) []GroupedSessionStats {
|
||||
result := make([]GroupedSessionStats, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
stats := GroupedSessionStats{
|
||||
Period: fmt.Sprintf("%v", row.WeekStart),
|
||||
SessionCount: row.SessionCount,
|
||||
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
|
||||
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
|
||||
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
|
||||
TotalCost: convertNullFloat64(row.TotalCost),
|
||||
AvgCost: convertNullFloat64(row.AvgCost),
|
||||
}
|
||||
result = append(result, stats)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertMonthStatsRows(rows []db.GetSessionStatsByMonthRow) []GroupedSessionStats {
|
||||
result := make([]GroupedSessionStats, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
stats := GroupedSessionStats{
|
||||
Period: fmt.Sprintf("%v", row.Month),
|
||||
SessionCount: row.SessionCount,
|
||||
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
|
||||
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
|
||||
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
|
||||
TotalCost: convertNullFloat64(row.TotalCost),
|
||||
AvgCost: convertNullFloat64(row.AvgCost),
|
||||
}
|
||||
result = append(result, stats)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func formatStats(stats SessionStats, format string) error {
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
return formatStatsJSON(stats)
|
||||
case "text":
|
||||
return formatStatsText(stats)
|
||||
default:
|
||||
return fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
func formatGroupedStats(stats []GroupedSessionStats, format, groupBy string) error {
|
||||
switch strings.ToLower(format) {
|
||||
case "json":
|
||||
return formatGroupedStatsJSON(stats)
|
||||
case "text":
|
||||
return formatGroupedStatsText(stats, groupBy)
|
||||
default:
|
||||
return fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
func formatStatsJSON(stats SessionStats) error {
|
||||
data, err := json.MarshalIndent(stats, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatStatsText(stats SessionStats) error {
|
||||
if stats.TotalSessions == 0 {
|
||||
fmt.Println("No sessions found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Session Statistics")
|
||||
fmt.Println("==================")
|
||||
fmt.Printf("Total Sessions: %d\n", stats.TotalSessions)
|
||||
fmt.Printf("Total Messages: %d\n", stats.TotalMessages)
|
||||
fmt.Printf("Total Prompt Tokens: %d\n", stats.TotalPromptTokens)
|
||||
fmt.Printf("Total Completion Tokens: %d\n", stats.TotalCompletionTokens)
|
||||
fmt.Printf("Total Cost: $%.4f\n", stats.TotalCost)
|
||||
fmt.Printf("Average Cost/Session: $%.4f\n", stats.AvgCostPerSession)
|
||||
|
||||
totalTokens := stats.TotalPromptTokens + stats.TotalCompletionTokens
|
||||
if totalTokens > 0 {
|
||||
fmt.Printf("Total Tokens: %d\n", totalTokens)
|
||||
fmt.Printf("Average Tokens/Session: %.1f\n", float64(totalTokens)/float64(stats.TotalSessions))
|
||||
}
|
||||
|
||||
if stats.TotalSessions > 0 {
|
||||
fmt.Printf("Average Messages/Session: %.1f\n", float64(stats.TotalMessages)/float64(stats.TotalSessions))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatGroupedStatsJSON(stats []GroupedSessionStats) error {
|
||||
data, err := json.MarshalIndent(stats, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
fmt.Println(string(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatGroupedStatsText(stats []GroupedSessionStats, groupBy string) error {
|
||||
if len(stats) == 0 {
|
||||
fmt.Printf("No sessions found for grouping by %s.\n", groupBy)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Session Statistics (Grouped by %s)\n", strings.ToUpper(groupBy[:1])+groupBy[1:])
|
||||
fmt.Println(strings.Repeat("=", 30+len(groupBy)))
|
||||
fmt.Println()
|
||||
|
||||
for _, stat := range stats {
|
||||
fmt.Printf("Period: %s\n", stat.Period)
|
||||
fmt.Printf(" Sessions: %d\n", stat.SessionCount)
|
||||
fmt.Printf(" Messages: %d\n", stat.MessageCount)
|
||||
fmt.Printf(" Prompt Tokens: %d\n", stat.PromptTokens)
|
||||
fmt.Printf(" Completion Tokens: %d\n", stat.CompletionTokens)
|
||||
fmt.Printf(" Total Cost: $%.4f\n", stat.TotalCost)
|
||||
fmt.Printf(" Average Cost: $%.4f\n", stat.AvgCost)
|
||||
totalTokens := stat.PromptTokens + stat.CompletionTokens
|
||||
if totalTokens > 0 {
|
||||
fmt.Printf(" Total Tokens: %d\n", totalTokens)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,6 +60,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
||||
if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
|
||||
}
|
||||
if q.getSessionStatsStmt, err = db.PrepareContext(ctx, getSessionStats); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionStats: %w", err)
|
||||
}
|
||||
if q.getSessionStatsByDayStmt, err = db.PrepareContext(ctx, getSessionStatsByDay); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionStatsByDay: %w", err)
|
||||
}
|
||||
if q.getSessionStatsByMonthStmt, err = db.PrepareContext(ctx, getSessionStatsByMonth); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionStatsByMonth: %w", err)
|
||||
}
|
||||
if q.getSessionStatsByWeekStmt, err = db.PrepareContext(ctx, getSessionStatsByWeek); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query GetSessionStatsByWeek: %w", err)
|
||||
}
|
||||
if q.listAllSessionsStmt, err = db.PrepareContext(ctx, listAllSessions); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListAllSessions: %w", err)
|
||||
}
|
||||
@@ -84,6 +96,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
|
||||
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
|
||||
}
|
||||
if q.searchSessionsByTextStmt, err = db.PrepareContext(ctx, searchSessionsByText); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query SearchSessionsByText: %w", err)
|
||||
}
|
||||
if q.searchSessionsByTitleStmt, err = db.PrepareContext(ctx, searchSessionsByTitle); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query SearchSessionsByTitle: %w", err)
|
||||
}
|
||||
if q.searchSessionsByTitleAndTextStmt, err = db.PrepareContext(ctx, searchSessionsByTitleAndText); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query SearchSessionsByTitleAndText: %w", err)
|
||||
}
|
||||
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
|
||||
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
|
||||
}
|
||||
@@ -155,6 +176,26 @@ func (q *Queries) Close() error {
|
||||
err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.getSessionStatsStmt != nil {
|
||||
if cerr := q.getSessionStatsStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing getSessionStatsStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.getSessionStatsByDayStmt != nil {
|
||||
if cerr := q.getSessionStatsByDayStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing getSessionStatsByDayStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.getSessionStatsByMonthStmt != nil {
|
||||
if cerr := q.getSessionStatsByMonthStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing getSessionStatsByMonthStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.getSessionStatsByWeekStmt != nil {
|
||||
if cerr := q.getSessionStatsByWeekStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing getSessionStatsByWeekStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.listAllSessionsStmt != nil {
|
||||
if cerr := q.listAllSessionsStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing listAllSessionsStmt: %w", cerr)
|
||||
@@ -195,6 +236,21 @@ func (q *Queries) Close() error {
|
||||
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.searchSessionsByTextStmt != nil {
|
||||
if cerr := q.searchSessionsByTextStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing searchSessionsByTextStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.searchSessionsByTitleStmt != nil {
|
||||
if cerr := q.searchSessionsByTitleStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing searchSessionsByTitleStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.searchSessionsByTitleAndTextStmt != nil {
|
||||
if cerr := q.searchSessionsByTitleAndTextStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing searchSessionsByTitleAndTextStmt: %w", cerr)
|
||||
}
|
||||
}
|
||||
if q.updateMessageStmt != nil {
|
||||
if cerr := q.updateMessageStmt.Close(); cerr != nil {
|
||||
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
|
||||
@@ -242,57 +298,71 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
tx *sql.Tx
|
||||
createFileStmt *sql.Stmt
|
||||
createMessageStmt *sql.Stmt
|
||||
createSessionStmt *sql.Stmt
|
||||
deleteFileStmt *sql.Stmt
|
||||
deleteMessageStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
deleteSessionFilesStmt *sql.Stmt
|
||||
deleteSessionMessagesStmt *sql.Stmt
|
||||
getFileStmt *sql.Stmt
|
||||
getFileByPathAndSessionStmt *sql.Stmt
|
||||
getMessageStmt *sql.Stmt
|
||||
getSessionByIDStmt *sql.Stmt
|
||||
listAllSessionsStmt *sql.Stmt
|
||||
listChildSessionsStmt *sql.Stmt
|
||||
listFilesByPathStmt *sql.Stmt
|
||||
listFilesBySessionStmt *sql.Stmt
|
||||
listLatestSessionFilesStmt *sql.Stmt
|
||||
listMessagesBySessionStmt *sql.Stmt
|
||||
listNewFilesStmt *sql.Stmt
|
||||
listSessionsStmt *sql.Stmt
|
||||
updateMessageStmt *sql.Stmt
|
||||
updateSessionStmt *sql.Stmt
|
||||
db DBTX
|
||||
tx *sql.Tx
|
||||
createFileStmt *sql.Stmt
|
||||
createMessageStmt *sql.Stmt
|
||||
createSessionStmt *sql.Stmt
|
||||
deleteFileStmt *sql.Stmt
|
||||
deleteMessageStmt *sql.Stmt
|
||||
deleteSessionStmt *sql.Stmt
|
||||
deleteSessionFilesStmt *sql.Stmt
|
||||
deleteSessionMessagesStmt *sql.Stmt
|
||||
getFileStmt *sql.Stmt
|
||||
getFileByPathAndSessionStmt *sql.Stmt
|
||||
getMessageStmt *sql.Stmt
|
||||
getSessionByIDStmt *sql.Stmt
|
||||
getSessionStatsStmt *sql.Stmt
|
||||
getSessionStatsByDayStmt *sql.Stmt
|
||||
getSessionStatsByMonthStmt *sql.Stmt
|
||||
getSessionStatsByWeekStmt *sql.Stmt
|
||||
listAllSessionsStmt *sql.Stmt
|
||||
listChildSessionsStmt *sql.Stmt
|
||||
listFilesByPathStmt *sql.Stmt
|
||||
listFilesBySessionStmt *sql.Stmt
|
||||
listLatestSessionFilesStmt *sql.Stmt
|
||||
listMessagesBySessionStmt *sql.Stmt
|
||||
listNewFilesStmt *sql.Stmt
|
||||
listSessionsStmt *sql.Stmt
|
||||
searchSessionsByTextStmt *sql.Stmt
|
||||
searchSessionsByTitleStmt *sql.Stmt
|
||||
searchSessionsByTitleAndTextStmt *sql.Stmt
|
||||
updateMessageStmt *sql.Stmt
|
||||
updateSessionStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
tx: tx,
|
||||
createFileStmt: q.createFileStmt,
|
||||
createMessageStmt: q.createMessageStmt,
|
||||
createSessionStmt: q.createSessionStmt,
|
||||
deleteFileStmt: q.deleteFileStmt,
|
||||
deleteMessageStmt: q.deleteMessageStmt,
|
||||
deleteSessionStmt: q.deleteSessionStmt,
|
||||
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
|
||||
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
|
||||
getFileStmt: q.getFileStmt,
|
||||
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
|
||||
getMessageStmt: q.getMessageStmt,
|
||||
getSessionByIDStmt: q.getSessionByIDStmt,
|
||||
listAllSessionsStmt: q.listAllSessionsStmt,
|
||||
listChildSessionsStmt: q.listChildSessionsStmt,
|
||||
listFilesByPathStmt: q.listFilesByPathStmt,
|
||||
listFilesBySessionStmt: q.listFilesBySessionStmt,
|
||||
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
|
||||
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
|
||||
listNewFilesStmt: q.listNewFilesStmt,
|
||||
listSessionsStmt: q.listSessionsStmt,
|
||||
updateMessageStmt: q.updateMessageStmt,
|
||||
updateSessionStmt: q.updateSessionStmt,
|
||||
db: tx,
|
||||
tx: tx,
|
||||
createFileStmt: q.createFileStmt,
|
||||
createMessageStmt: q.createMessageStmt,
|
||||
createSessionStmt: q.createSessionStmt,
|
||||
deleteFileStmt: q.deleteFileStmt,
|
||||
deleteMessageStmt: q.deleteMessageStmt,
|
||||
deleteSessionStmt: q.deleteSessionStmt,
|
||||
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
|
||||
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
|
||||
getFileStmt: q.getFileStmt,
|
||||
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
|
||||
getMessageStmt: q.getMessageStmt,
|
||||
getSessionByIDStmt: q.getSessionByIDStmt,
|
||||
getSessionStatsStmt: q.getSessionStatsStmt,
|
||||
getSessionStatsByDayStmt: q.getSessionStatsByDayStmt,
|
||||
getSessionStatsByMonthStmt: q.getSessionStatsByMonthStmt,
|
||||
getSessionStatsByWeekStmt: q.getSessionStatsByWeekStmt,
|
||||
listAllSessionsStmt: q.listAllSessionsStmt,
|
||||
listChildSessionsStmt: q.listChildSessionsStmt,
|
||||
listFilesByPathStmt: q.listFilesByPathStmt,
|
||||
listFilesBySessionStmt: q.listFilesBySessionStmt,
|
||||
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
|
||||
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
|
||||
listNewFilesStmt: q.listNewFilesStmt,
|
||||
listSessionsStmt: q.listSessionsStmt,
|
||||
searchSessionsByTextStmt: q.searchSessionsByTextStmt,
|
||||
searchSessionsByTitleStmt: q.searchSessionsByTitleStmt,
|
||||
searchSessionsByTitleAndTextStmt: q.searchSessionsByTitleAndTextStmt,
|
||||
updateMessageStmt: q.updateMessageStmt,
|
||||
updateSessionStmt: q.updateSessionStmt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,10 @@ type Querier interface {
|
||||
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
|
||||
GetMessage(ctx context.Context, id string) (Message, error)
|
||||
GetSessionByID(ctx context.Context, id string) (Session, error)
|
||||
GetSessionStats(ctx context.Context) (GetSessionStatsRow, error)
|
||||
GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error)
|
||||
GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error)
|
||||
GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error)
|
||||
ListAllSessions(ctx context.Context) ([]Session, error)
|
||||
ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error)
|
||||
ListFilesByPath(ctx context.Context, path string) ([]File, error)
|
||||
@@ -30,6 +34,9 @@ type Querier interface {
|
||||
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListNewFiles(ctx context.Context) ([]File, error)
|
||||
ListSessions(ctx context.Context) ([]Session, error)
|
||||
SearchSessionsByText(ctx context.Context, parts string) ([]Session, error)
|
||||
SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error)
|
||||
SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error)
|
||||
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
|
||||
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||
}
|
||||
|
||||
@@ -106,6 +106,205 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSessionStats = `-- name: GetSessionStats :one
|
||||
SELECT
|
||||
COUNT(*) as total_sessions,
|
||||
SUM(message_count) as total_messages,
|
||||
SUM(prompt_tokens) as total_prompt_tokens,
|
||||
SUM(completion_tokens) as total_completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost_per_session
|
||||
FROM sessions
|
||||
`
|
||||
|
||||
type GetSessionStatsRow struct {
|
||||
TotalSessions int64 `json:"total_sessions"`
|
||||
TotalMessages sql.NullFloat64 `json:"total_messages"`
|
||||
TotalPromptTokens sql.NullFloat64 `json:"total_prompt_tokens"`
|
||||
TotalCompletionTokens sql.NullFloat64 `json:"total_completion_tokens"`
|
||||
TotalCost sql.NullFloat64 `json:"total_cost"`
|
||||
AvgCostPerSession sql.NullFloat64 `json:"avg_cost_per_session"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionStats(ctx context.Context) (GetSessionStatsRow, error) {
|
||||
row := q.queryRow(ctx, q.getSessionStatsStmt, getSessionStats)
|
||||
var i GetSessionStatsRow
|
||||
err := row.Scan(
|
||||
&i.TotalSessions,
|
||||
&i.TotalMessages,
|
||||
&i.TotalPromptTokens,
|
||||
&i.TotalCompletionTokens,
|
||||
&i.TotalCost,
|
||||
&i.AvgCostPerSession,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSessionStatsByDay = `-- name: GetSessionStatsByDay :many
|
||||
SELECT
|
||||
date(created_at, 'unixepoch') as day,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY date(created_at, 'unixepoch')
|
||||
ORDER BY day DESC
|
||||
`
|
||||
|
||||
type GetSessionStatsByDayRow struct {
|
||||
Day interface{} `json:"day"`
|
||||
SessionCount int64 `json:"session_count"`
|
||||
MessageCount sql.NullFloat64 `json:"message_count"`
|
||||
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
|
||||
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
|
||||
TotalCost sql.NullFloat64 `json:"total_cost"`
|
||||
AvgCost sql.NullFloat64 `json:"avg_cost"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error) {
|
||||
rows, err := q.query(ctx, q.getSessionStatsByDayStmt, getSessionStatsByDay)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []GetSessionStatsByDayRow{}
|
||||
for rows.Next() {
|
||||
var i GetSessionStatsByDayRow
|
||||
if err := rows.Scan(
|
||||
&i.Day,
|
||||
&i.SessionCount,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.TotalCost,
|
||||
&i.AvgCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSessionStatsByMonth = `-- name: GetSessionStatsByMonth :many
|
||||
SELECT
|
||||
strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
|
||||
ORDER BY month DESC
|
||||
`
|
||||
|
||||
type GetSessionStatsByMonthRow struct {
|
||||
Month interface{} `json:"month"`
|
||||
SessionCount int64 `json:"session_count"`
|
||||
MessageCount sql.NullFloat64 `json:"message_count"`
|
||||
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
|
||||
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
|
||||
TotalCost sql.NullFloat64 `json:"total_cost"`
|
||||
AvgCost sql.NullFloat64 `json:"avg_cost"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error) {
|
||||
rows, err := q.query(ctx, q.getSessionStatsByMonthStmt, getSessionStatsByMonth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []GetSessionStatsByMonthRow{}
|
||||
for rows.Next() {
|
||||
var i GetSessionStatsByMonthRow
|
||||
if err := rows.Scan(
|
||||
&i.Month,
|
||||
&i.SessionCount,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.TotalCost,
|
||||
&i.AvgCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSessionStatsByWeek = `-- name: GetSessionStatsByWeek :many
|
||||
SELECT
|
||||
date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
|
||||
ORDER BY week_start DESC
|
||||
`
|
||||
|
||||
type GetSessionStatsByWeekRow struct {
|
||||
WeekStart interface{} `json:"week_start"`
|
||||
SessionCount int64 `json:"session_count"`
|
||||
MessageCount sql.NullFloat64 `json:"message_count"`
|
||||
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
|
||||
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
|
||||
TotalCost sql.NullFloat64 `json:"total_cost"`
|
||||
AvgCost sql.NullFloat64 `json:"avg_cost"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error) {
|
||||
rows, err := q.query(ctx, q.getSessionStatsByWeekStmt, getSessionStatsByWeek)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []GetSessionStatsByWeekRow{}
|
||||
for rows.Next() {
|
||||
var i GetSessionStatsByWeekRow
|
||||
if err := rows.Scan(
|
||||
&i.WeekStart,
|
||||
&i.SessionCount,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.TotalCost,
|
||||
&i.AvgCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAllSessions = `-- name: ListAllSessions :many
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
FROM sessions
|
||||
@@ -228,6 +427,136 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const searchSessionsByText = `-- name: SearchSessionsByText :many
|
||||
SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
|
||||
FROM sessions s
|
||||
JOIN messages m ON s.id = m.session_id
|
||||
WHERE m.parts LIKE ?
|
||||
ORDER BY s.created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) SearchSessionsByText(ctx context.Context, parts string) ([]Session, error) {
|
||||
rows, err := q.query(ctx, q.searchSessionsByTextStmt, searchSessionsByText, parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Session{}
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const searchSessionsByTitle = `-- name: SearchSessionsByTitle :many
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
FROM sessions
|
||||
WHERE title LIKE ?
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error) {
|
||||
rows, err := q.query(ctx, q.searchSessionsByTitleStmt, searchSessionsByTitle, title)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Session{}
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const searchSessionsByTitleAndText = `-- name: SearchSessionsByTitleAndText :many
|
||||
SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
|
||||
FROM sessions s
|
||||
JOIN messages m ON s.id = m.session_id
|
||||
WHERE s.title LIKE ? AND m.parts LIKE ?
|
||||
ORDER BY s.created_at DESC
|
||||
`
|
||||
|
||||
type SearchSessionsByTitleAndTextParams struct {
|
||||
Title string `json:"title"`
|
||||
Parts string `json:"parts"`
|
||||
}
|
||||
|
||||
func (q *Queries) SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error) {
|
||||
rows, err := q.query(ctx, q.searchSessionsByTitleAndTextStmt, searchSessionsByTitleAndText, arg.Title, arg.Parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Session{}
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateSession = `-- name: UpdateSession :one
|
||||
UPDATE sessions
|
||||
SET
|
||||
|
||||
@@ -60,3 +60,72 @@ ORDER BY created_at ASC;
|
||||
SELECT *
|
||||
FROM sessions
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: SearchSessionsByTitle :many
|
||||
SELECT *
|
||||
FROM sessions
|
||||
WHERE title LIKE ?
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: SearchSessionsByTitleAndText :many
|
||||
SELECT DISTINCT s.*
|
||||
FROM sessions s
|
||||
JOIN messages m ON s.id = m.session_id
|
||||
WHERE s.title LIKE ? AND m.parts LIKE ?
|
||||
ORDER BY s.created_at DESC;
|
||||
|
||||
-- name: SearchSessionsByText :many
|
||||
SELECT DISTINCT s.*
|
||||
FROM sessions s
|
||||
JOIN messages m ON s.id = m.session_id
|
||||
WHERE m.parts LIKE ?
|
||||
ORDER BY s.created_at DESC;
|
||||
|
||||
-- name: GetSessionStats :one
|
||||
SELECT
|
||||
COUNT(*) as total_sessions,
|
||||
SUM(message_count) as total_messages,
|
||||
SUM(prompt_tokens) as total_prompt_tokens,
|
||||
SUM(completion_tokens) as total_completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost_per_session
|
||||
FROM sessions;
|
||||
|
||||
-- name: GetSessionStatsByDay :many
|
||||
SELECT
|
||||
date(created_at, 'unixepoch') as day,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY date(created_at, 'unixepoch')
|
||||
ORDER BY day DESC;
|
||||
|
||||
-- name: GetSessionStatsByWeek :many
|
||||
SELECT
|
||||
date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
|
||||
ORDER BY week_start DESC;
|
||||
|
||||
-- name: GetSessionStatsByMonth :many
|
||||
SELECT
|
||||
strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
|
||||
COUNT(*) as session_count,
|
||||
SUM(message_count) as message_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(cost) as avg_cost
|
||||
FROM sessions
|
||||
GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
|
||||
ORDER BY month DESC;
|
||||
|
||||
@@ -33,6 +33,9 @@ type Service interface {
|
||||
ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
|
||||
Save(ctx context.Context, session Session) (Session, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error)
|
||||
SearchByText(ctx context.Context, textPattern string) ([]Session, error)
|
||||
SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
@@ -161,6 +164,45 @@ func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]S
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) {
|
||||
dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSession := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSession)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) {
|
||||
dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSession := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSession)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) {
|
||||
dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{
|
||||
Title: "%" + titlePattern + "%",
|
||||
Parts: "%" + textPattern + "%",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSession := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSession)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s service) fromDBItem(item db.Session) Session {
|
||||
return Session{
|
||||
ID: item.ID,
|
||||
|
||||
297
rfcs/session-import-export.md
Normal file
297
rfcs/session-import-export.md
Normal file
@@ -0,0 +1,297 @@
|
||||
# RFC: Session Import and Export
|
||||
|
||||
## Summary
|
||||
|
||||
This RFC proposes a comprehensive system for importing and exporting conversation sessions in Crush.
|
||||
|
||||
## Background
|
||||
|
||||
Crush manages conversations through a hierarchical session system where:
|
||||
- Sessions contain metadata (title, token counts, cost, timestamps)
|
||||
- Sessions can have parent-child relationships (nested conversations)
|
||||
- Messages within sessions have structured content parts (text, tool calls, reasoning, etc.)
|
||||
- The current implementation provides export functionality but lacks import capabilities
|
||||
|
||||
The latest commit introduced three key commands:
|
||||
- `crush sessions list` - List sessions in various formats
|
||||
- `crush sessions export` - Export all sessions and metadata
|
||||
- `crush sessions export-conversation <session-id>` - Export a single conversation with messages
|
||||
|
||||
## Motivation
|
||||
|
||||
Users need to:
|
||||
1. Share conversations with others
|
||||
2. Use conversation logs for debugging
|
||||
3. Archive and analyze conversation history
|
||||
4. Export data for external tools
|
||||
|
||||
## Detailed Design
|
||||
|
||||
### Core Data Model
|
||||
|
||||
The session export format builds on the existing session structure:
|
||||
|
||||
```go
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
ParentSessionID string `json:"parent_session_id,omitempty"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
SummaryMessageID string `json:"summary_message_id,omitempty"`
|
||||
}
|
||||
|
||||
type SessionWithChildren struct {
|
||||
Session
|
||||
Children []SessionWithChildren `json:"children,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
### Proposed Command Interface
|
||||
|
||||
#### Export Commands (Already Implemented)
|
||||
```bash
|
||||
# List sessions in various formats
|
||||
crush sessions list [--format text|json|yaml|markdown]
|
||||
|
||||
# Export all sessions with metadata
|
||||
crush sessions export [--format json|yaml|markdown]
|
||||
|
||||
# Export single conversation with full message history
|
||||
crush sessions export-conversation <session-id> [--format markdown|json|yaml]
|
||||
```
|
||||
|
||||
#### New Import Commands
|
||||
|
||||
```bash
|
||||
# Import sessions from a file
|
||||
crush sessions import <file> [--format json|yaml] [--dry-run]
|
||||
|
||||
# Import a single conversation
|
||||
crush sessions import-conversation <file> [--format json|yaml|markdown]
|
||||
|
||||
```
|
||||
|
||||
#### Enhanced Inspection Commands
|
||||
|
||||
```bash
|
||||
# Search sessions by criteria
|
||||
crush sessions search [--title <pattern>] [--text <text>] [--format text|json]
|
||||
|
||||
# Show session statistics
|
||||
crush sessions stats [--format text|json] [--group-by day|week|month]
|
||||
|
||||
# Show statistics for a single session
|
||||
crush sessions stats <session-id> [--format text|json]
|
||||
```
|
||||
|
||||
### Import/Export Formats
|
||||
|
||||
#### Full Export Format (JSON)
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"exported_at": "2025-01-27T10:30:00Z",
|
||||
"total_sessions": 15,
|
||||
"sessions": [
|
||||
{
|
||||
"id": "session-123",
|
||||
"parent_session_id": "",
|
||||
"title": "API Design Discussion",
|
||||
"message_count": 8,
|
||||
"prompt_tokens": 1250,
|
||||
"completion_tokens": 890,
|
||||
"cost": 0.0234,
|
||||
"created_at": 1706356200,
|
||||
"updated_at": 1706359800,
|
||||
"children": [
|
||||
{
|
||||
"id": "session-124",
|
||||
"parent_session_id": "session-123",
|
||||
"title": "Implementation Details",
|
||||
"message_count": 4,
|
||||
"prompt_tokens": 650,
|
||||
"completion_tokens": 420,
|
||||
"cost": 0.0145,
|
||||
"created_at": 1706359900,
|
||||
"updated_at": 1706361200
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Conversation Export Format (JSON)
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"session": {
|
||||
"id": "session-123",
|
||||
"title": "API Design Discussion",
|
||||
"created_at": 1706356200,
|
||||
"message_count": 3
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"id": "msg-001",
|
||||
"session_id": "session-123",
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": {
|
||||
"text": "Help me design a REST API for user management"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created_at": 1706356200
|
||||
},
|
||||
{
|
||||
"id": "msg-002",
|
||||
"session_id": "session-123",
|
||||
"role": "assistant",
|
||||
"model": "gpt-4",
|
||||
"provider": "openai",
|
||||
"parts": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": {
|
||||
"text": "I'll help you design a REST API for user management..."
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "finish",
|
||||
"data": {
|
||||
"reason": "stop",
|
||||
"time": 1706356230
|
||||
}
|
||||
}
|
||||
],
|
||||
"created_at": 1706356220
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### API Implementation
|
||||
|
||||
#### Import Service Interface
|
||||
```go
|
||||
type ImportService interface {
|
||||
// Import sessions from structured data
|
||||
ImportSessions(ctx context.Context, data ImportData, opts ImportOptions) (ImportResult, error)
|
||||
|
||||
// Import single conversation
|
||||
ImportConversation(ctx context.Context, data ConversationData, opts ImportOptions) (Session, error)
|
||||
|
||||
// Validate import data without persisting
|
||||
ValidateImport(ctx context.Context, data ImportData) (ValidationResult, error)
|
||||
}
|
||||
|
||||
type ImportOptions struct {
|
||||
ConflictStrategy ConflictStrategy // skip, merge, replace
|
||||
DryRun bool
|
||||
ParentSessionID string // For conversation imports
|
||||
PreserveIDs bool // Whether to preserve original IDs
|
||||
}
|
||||
|
||||
type ConflictStrategy string
|
||||
|
||||
const (
|
||||
ConflictSkip ConflictStrategy = "skip" // Skip existing sessions
|
||||
ConflictMerge ConflictStrategy = "merge" // Merge with existing
|
||||
ConflictReplace ConflictStrategy = "replace" // Replace existing
|
||||
)
|
||||
|
||||
type ImportResult struct {
|
||||
TotalSessions int `json:"total_sessions"`
|
||||
ImportedSessions int `json:"imported_sessions"`
|
||||
SkippedSessions int `json:"skipped_sessions"`
|
||||
Errors []ImportError `json:"errors,omitempty"`
|
||||
SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
|
||||
}
|
||||
```
|
||||
|
||||
#### Enhanced Export Service
|
||||
```go
|
||||
type ExportService interface {
|
||||
// Export sessions with filtering
|
||||
ExportSessions(ctx context.Context, opts ExportOptions) ([]SessionWithChildren, error)
|
||||
|
||||
// Export conversation with full message history
|
||||
ExportConversation(ctx context.Context, sessionID string, opts ExportOptions) (ConversationExport, error)
|
||||
|
||||
// Search and filter sessions
|
||||
SearchSessions(ctx context.Context, criteria SearchCriteria) ([]Session, error)
|
||||
|
||||
// Get session statistics
|
||||
GetStats(ctx context.Context, opts StatsOptions) (SessionStats, error)
|
||||
}
|
||||
|
||||
type ExportOptions struct {
|
||||
Format string // json, yaml, markdown
|
||||
IncludeMessages bool // Include full message content
|
||||
DateRange DateRange // Filter by date range
|
||||
SessionIDs []string // Export specific sessions
|
||||
}
|
||||
|
||||
type SearchCriteria struct {
|
||||
TitlePattern string
|
||||
DateRange DateRange
|
||||
MinCost float64
|
||||
MaxCost float64
|
||||
ParentSessionID string
|
||||
HasChildren *bool
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Status
|
||||
|
||||
The proposed session import/export functionality has been implemented as a prototype as of July 2025.
|
||||
|
||||
### Implemented Commands
|
||||
|
||||
All new commands have been added to `internal/cmd/sessions.go`:
|
||||
|
||||
- **Import**: `crush sessions import <file> [--format json|yaml] [--dry-run]`
|
||||
- Supports hierarchical session imports with parent-child relationships
|
||||
- Generates new UUIDs to avoid conflicts
|
||||
- Includes validation and dry-run capabilities
|
||||
|
||||
- **Import Conversation**: `crush sessions import-conversation <file> [--format json|yaml]`
|
||||
- Imports single conversations with full message history
|
||||
- Preserves all message content parts and metadata
|
||||
|
||||
- **Search**: `crush sessions search [--title <pattern>] [--text <text>] [--format text|json]`
|
||||
- Case-insensitive title search and message content search
|
||||
- Supports combined search criteria with AND logic
|
||||
|
||||
- **Stats**: `crush sessions stats [--format text|json] [--group-by day|week|month]`
|
||||
- Comprehensive usage statistics (sessions, messages, tokens, costs)
|
||||
- Time-based grouping with efficient database queries
|
||||
|
||||
### Database Changes
|
||||
|
||||
Added new SQL queries in `internal/db/sql/sessions.sql`:
|
||||
- Search queries for title and message content filtering
|
||||
- Statistics aggregation queries with time-based grouping
|
||||
- All queries optimized for performance with proper indexing
|
||||
|
||||
### Database Schema Considerations
|
||||
|
||||
The current schema supports the import/export functionality. Additional indexes may be needed for search performance:
|
||||
|
||||
```sql
|
||||
-- Optimize session searches by date and cost
|
||||
CREATE INDEX idx_sessions_created_at ON sessions(created_at);
|
||||
CREATE INDEX idx_sessions_cost ON sessions(cost);
|
||||
CREATE INDEX idx_sessions_title ON sessions(title COLLATE NOCASE);
|
||||
|
||||
-- Optimize message searches by session
|
||||
CREATE INDEX idx_messages_session_created ON messages(session_id, created_at);
|
||||
```
|
||||
Reference in New Issue
Block a user