Add RFC and implementation of import and stats

This commit is contained in:
Manuel Odendahl
2025-07-27 15:37:00 -04:00
parent 797788a086
commit e0eb35900c
7 changed files with 1784 additions and 48 deletions

View File

@@ -1,10 +1,13 @@
package cmd
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -12,6 +15,7 @@ import (
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
"github.com/google/uuid"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
@@ -22,6 +26,89 @@ type SessionWithChildren struct {
Children []SessionWithChildren `json:"children,omitempty" yaml:"children,omitempty"`
}
// ImportSession represents a session with proper JSON tags for import
type ImportSession struct {
ID string `json:"id"`
ParentSessionID string `json:"parent_session_id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
SummaryMessageID string `json:"summary_message_id,omitempty"`
Children []ImportSession `json:"children,omitempty"`
}
// ImportData represents the full import structure for sessions
type ImportData struct {
Version string `json:"version" yaml:"version"`
ExportedAt string `json:"exported_at,omitempty" yaml:"exported_at,omitempty"`
TotalSessions int `json:"total_sessions,omitempty" yaml:"total_sessions,omitempty"`
Sessions []ImportSession `json:"sessions" yaml:"sessions"`
}
// ImportMessage represents a message with proper JSON tags for import
type ImportMessage struct {
ID string `json:"id"`
Role string `json:"role"`
SessionID string `json:"session_id"`
Parts []interface{} `json:"parts"`
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
CreatedAt int64 `json:"created_at"`
}
// ImportSessionInfo represents session info with proper JSON tags for conversation import
type ImportSessionInfo struct {
ID string `json:"id"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens,omitempty"`
CompletionTokens int64 `json:"completion_tokens,omitempty"`
Cost float64 `json:"cost,omitempty"`
CreatedAt int64 `json:"created_at"`
}
// ConversationData represents a single conversation import structure
type ConversationData struct {
Version string `json:"version" yaml:"version"`
Session ImportSessionInfo `json:"session" yaml:"session"`
Messages []ImportMessage `json:"messages" yaml:"messages"`
}
// ImportResult contains the results of an import operation
type ImportResult struct {
TotalSessions int `json:"total_sessions"`
ImportedSessions int `json:"imported_sessions"`
SkippedSessions int `json:"skipped_sessions"`
ImportedMessages int `json:"imported_messages"`
Errors []string `json:"errors,omitempty"`
SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
}
// SessionStats represents aggregated session statistics
type SessionStats struct {
TotalSessions int64 `json:"total_sessions"`
TotalMessages int64 `json:"total_messages"`
TotalPromptTokens int64 `json:"total_prompt_tokens"`
TotalCompletionTokens int64 `json:"total_completion_tokens"`
TotalCost float64 `json:"total_cost"`
AvgCostPerSession float64 `json:"avg_cost_per_session"`
}
// GroupedSessionStats represents statistics grouped by time period
type GroupedSessionStats struct {
Period string `json:"period"`
SessionCount int64 `json:"session_count"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalCost float64 `json:"total_cost"`
AvgCost float64 `json:"avg_cost"`
}
var sessionsCmd = &cobra.Command{
Use: "sessions",
Short: "Manage sessions",
@@ -60,15 +147,80 @@ var exportConversationCmd = &cobra.Command{
},
}
var importCmd = &cobra.Command{
Use: "import <file>",
Short: "Import sessions from a file",
Long: `Import sessions from a JSON or YAML file with hierarchical structure`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
file := args[0]
format, _ := cmd.Flags().GetString("format")
dryRun, _ := cmd.Flags().GetBool("dry-run")
return runImport(cmd.Context(), file, format, dryRun)
},
}
var importConversationCmd = &cobra.Command{
Use: "import-conversation <file>",
Short: "Import a single conversation from a file",
Long: `Import a single conversation with messages from a JSON, YAML, or Markdown file`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
file := args[0]
format, _ := cmd.Flags().GetString("format")
return runImportConversation(cmd.Context(), file, format)
},
}
var searchCmd = &cobra.Command{
Use: "search",
Short: "Search sessions by title or message content",
Long: `Search sessions by title pattern (case-insensitive) or message text content`,
RunE: func(cmd *cobra.Command, args []string) error {
titlePattern, _ := cmd.Flags().GetString("title")
textPattern, _ := cmd.Flags().GetString("text")
format, _ := cmd.Flags().GetString("format")
if titlePattern == "" && textPattern == "" {
return fmt.Errorf("at least one of --title or --text must be provided")
}
return runSessionsSearch(cmd.Context(), titlePattern, textPattern, format)
},
}
var statsCmd = &cobra.Command{
Use: "stats",
Short: "Show session statistics",
Long: `Display aggregated statistics about sessions including total counts, tokens, and costs`,
RunE: func(cmd *cobra.Command, args []string) error {
format, _ := cmd.Flags().GetString("format")
groupBy, _ := cmd.Flags().GetString("group-by")
return runSessionsStats(cmd.Context(), format, groupBy)
},
}
func init() {
rootCmd.AddCommand(sessionsCmd)
sessionsCmd.AddCommand(listCmd)
sessionsCmd.AddCommand(exportCmd)
sessionsCmd.AddCommand(exportConversationCmd)
sessionsCmd.AddCommand(importCmd)
sessionsCmd.AddCommand(importConversationCmd)
sessionsCmd.AddCommand(searchCmd)
sessionsCmd.AddCommand(statsCmd)
listCmd.Flags().StringP("format", "f", "text", "Output format (text, json, yaml, markdown)")
exportCmd.Flags().StringP("format", "f", "json", "Export format (json, yaml, markdown)")
exportConversationCmd.Flags().StringP("format", "f", "markdown", "Export format (markdown, json, yaml)")
importCmd.Flags().StringP("format", "f", "", "Import format (json, yaml) - auto-detected if not specified")
importCmd.Flags().Bool("dry-run", false, "Validate import data without persisting changes")
importConversationCmd.Flags().StringP("format", "f", "", "Import format (json, yaml, markdown) - auto-detected if not specified")
searchCmd.Flags().String("title", "", "Search by session title pattern (case-insensitive substring search)")
searchCmd.Flags().String("text", "", "Search by message text content")
searchCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
statsCmd.Flags().StringP("format", "f", "text", "Output format (text, json)")
statsCmd.Flags().String("group-by", "", "Group statistics by time period (day, week, month)")
}
func runSessionsList(ctx context.Context, format string) error {
@@ -442,3 +594,773 @@ func formatConversationYAML(sess session.Session, messages []message.Message) er
fmt.Println(string(yamlData))
return nil
}
func runImport(ctx context.Context, file, format string, dryRun bool) error {
// Read the file
data, err := readImportFile(file, format)
if err != nil {
return fmt.Errorf("failed to read import file: %w", err)
}
// Validate the data structure
if err := validateImportData(data); err != nil {
return fmt.Errorf("invalid import data: %w", err)
}
if dryRun {
result := ImportResult{
TotalSessions: countTotalImportSessions(data.Sessions),
ImportedSessions: 0,
SkippedSessions: 0,
ImportedMessages: 0,
SessionMapping: make(map[string]string),
}
fmt.Printf("Dry run: Would import %d sessions\n", result.TotalSessions)
return nil
}
// Perform the actual import
sessionService, messageService, err := createServices(ctx)
if err != nil {
return err
}
result, err := importSessions(ctx, sessionService, messageService, data)
if err != nil {
return fmt.Errorf("import failed: %w", err)
}
// Print summary
fmt.Printf("Import completed successfully:\n")
fmt.Printf(" Total sessions processed: %d\n", result.TotalSessions)
fmt.Printf(" Sessions imported: %d\n", result.ImportedSessions)
fmt.Printf(" Sessions skipped: %d\n", result.SkippedSessions)
fmt.Printf(" Messages imported: %d\n", result.ImportedMessages)
if len(result.Errors) > 0 {
fmt.Printf(" Errors encountered: %d\n", len(result.Errors))
for _, errStr := range result.Errors {
fmt.Printf(" - %s\n", errStr)
}
}
return nil
}
func runImportConversation(ctx context.Context, file, format string) error {
// Read the conversation file
convData, err := readConversationFile(file, format)
if err != nil {
return fmt.Errorf("failed to read conversation file: %w", err)
}
// Validate the conversation data
if err := validateConversationData(convData); err != nil {
return fmt.Errorf("invalid conversation data: %w", err)
}
// Import the conversation
sessionService, messageService, err := createServices(ctx)
if err != nil {
return err
}
newSessionID, messageCount, err := importConversation(ctx, sessionService, messageService, convData)
if err != nil {
return fmt.Errorf("conversation import failed: %w", err)
}
fmt.Printf("Conversation imported successfully:\n")
fmt.Printf(" Session ID: %s\n", newSessionID)
fmt.Printf(" Title: %s\n", convData.Session.Title)
fmt.Printf(" Messages imported: %d\n", messageCount)
return nil
}
func readImportFile(file, format string) (*ImportData, error) {
fileData, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %w", file, err)
}
// Auto-detect format if not specified
if format == "" {
format = detectFormat(file, fileData)
}
var data ImportData
switch strings.ToLower(format) {
case "json":
if err := json.Unmarshal(fileData, &data); err != nil {
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
case "yaml", "yml":
if err := yaml.Unmarshal(fileData, &data); err != nil {
return nil, fmt.Errorf("failed to parse YAML: %w", err)
}
default:
return nil, fmt.Errorf("unsupported format: %s", format)
}
return &data, nil
}
func readConversationFile(file, format string) (*ConversationData, error) {
fileData, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %w", file, err)
}
// Auto-detect format if not specified
if format == "" {
format = detectFormat(file, fileData)
}
var data ConversationData
switch strings.ToLower(format) {
case "json":
if err := json.Unmarshal(fileData, &data); err != nil {
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
case "yaml", "yml":
if err := yaml.Unmarshal(fileData, &data); err != nil {
return nil, fmt.Errorf("failed to parse YAML: %w", err)
}
case "markdown", "md":
return nil, fmt.Errorf("markdown import for conversations is not yet implemented")
default:
return nil, fmt.Errorf("unsupported format: %s", format)
}
return &data, nil
}
func detectFormat(filename string, data []byte) string {
// First try file extension
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".json":
return "json"
case ".yaml", ".yml":
return "yaml"
case ".md", ".markdown":
return "markdown"
}
// Try to detect from content
data = bytes.TrimSpace(data)
if len(data) > 0 {
if data[0] == '{' || data[0] == '[' {
return "json"
}
if strings.HasPrefix(string(data), "---") || strings.Contains(string(data), ":") {
return "yaml"
}
}
return "json" // default fallback
}
func validateImportData(data *ImportData) error {
if data == nil {
return fmt.Errorf("import data is nil")
}
if len(data.Sessions) == 0 {
return fmt.Errorf("no sessions to import")
}
// Validate session structure
for i, sess := range data.Sessions {
if err := validateImportSessionHierarchy(sess, ""); err != nil {
return fmt.Errorf("session %d validation failed: %w", i, err)
}
}
return nil
}
func validateConversationData(data *ConversationData) error {
if data == nil {
return fmt.Errorf("conversation data is nil")
}
if data.Session.Title == "" {
return fmt.Errorf("session title is required")
}
if len(data.Messages) == 0 {
return fmt.Errorf("no messages to import")
}
return nil
}
func validateImportSessionHierarchy(sess ImportSession, expectedParent string) error {
if sess.ID == "" {
return fmt.Errorf("session ID is required")
}
if sess.Title == "" {
return fmt.Errorf("session title is required")
}
// For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
if expectedParent == "" {
if sess.ParentSessionID != "" {
return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
}
} else {
// For child sessions, parent should match expected parent
if sess.ParentSessionID != expectedParent {
return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
}
}
// Validate children
for _, child := range sess.Children {
if err := validateImportSessionHierarchy(child, sess.ID); err != nil {
return err
}
}
return nil
}
func validateSessionHierarchy(sess SessionWithChildren, expectedParent string) error {
if sess.ID == "" {
return fmt.Errorf("session ID is required")
}
if sess.Title == "" {
return fmt.Errorf("session title is required")
}
// For top-level sessions, expectedParent should be empty and session should have no parent or empty parent
if expectedParent == "" {
if sess.ParentSessionID != "" {
return fmt.Errorf("top-level session should not have a parent, got %s", sess.ParentSessionID)
}
} else {
// For child sessions, parent should match expected parent
if sess.ParentSessionID != expectedParent {
return fmt.Errorf("parent session ID mismatch: expected %s, got %s (session ID: %s)", expectedParent, sess.ParentSessionID, sess.ID)
}
}
// Validate children
for _, child := range sess.Children {
if err := validateSessionHierarchy(child, sess.ID); err != nil {
return err
}
}
return nil
}
func countTotalImportSessions(sessions []ImportSession) int {
count := len(sessions)
for _, sess := range sessions {
count += countTotalImportSessions(sess.Children)
}
return count
}
func countTotalSessions(sessions []SessionWithChildren) int {
count := len(sessions)
for _, sess := range sessions {
count += countTotalSessions(sess.Children)
}
return count
}
func importSessions(ctx context.Context, sessionService session.Service, messageService message.Service, data *ImportData) (ImportResult, error) {
result := ImportResult{
TotalSessions: countTotalImportSessions(data.Sessions),
SessionMapping: make(map[string]string),
}
// Import sessions recursively, starting with top-level sessions
for _, sess := range data.Sessions {
err := importImportSessionWithChildren(ctx, sessionService, messageService, sess, "", &result)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("failed to import session %s: %v", sess.ID, err))
}
}
return result, nil
}
func importConversation(ctx context.Context, sessionService session.Service, messageService message.Service, data *ConversationData) (string, int, error) {
// Generate new session ID
newSessionID := uuid.New().String()
// Create the session using the low-level database API
cwd, err := getCwd()
if err != nil {
return "", 0, err
}
cfg, err := config.Init(cwd, false)
if err != nil {
return "", 0, err
}
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
if err != nil {
return "", 0, err
}
queries := db.New(conn)
// Create session with all original metadata
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
ID: newSessionID,
ParentSessionID: sql.NullString{Valid: false},
Title: data.Session.Title,
MessageCount: data.Session.MessageCount,
PromptTokens: data.Session.PromptTokens,
CompletionTokens: data.Session.CompletionTokens,
Cost: data.Session.Cost,
})
if err != nil {
return "", 0, fmt.Errorf("failed to create session: %w", err)
}
// Import messages
messageCount := 0
for _, msg := range data.Messages {
// Generate new message ID
newMessageID := uuid.New().String()
// Marshal message parts
partsJSON, err := json.Marshal(msg.Parts)
if err != nil {
return "", 0, fmt.Errorf("failed to marshal message parts: %w", err)
}
// Create message
_, err = queries.CreateMessage(ctx, db.CreateMessageParams{
ID: newMessageID,
SessionID: newSessionID,
Role: string(msg.Role),
Parts: string(partsJSON),
Model: sql.NullString{String: msg.Model, Valid: msg.Model != ""},
Provider: sql.NullString{String: msg.Provider, Valid: msg.Provider != ""},
})
if err != nil {
return "", 0, fmt.Errorf("failed to create message: %w", err)
}
messageCount++
}
return newSessionID, messageCount, nil
}
func importImportSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess ImportSession, parentID string, result *ImportResult) error {
// Generate new session ID
newSessionID := uuid.New().String()
result.SessionMapping[sess.ID] = newSessionID
// Create the session using the low-level database API to preserve metadata
cwd, err := getCwd()
if err != nil {
return err
}
cfg, err := config.Init(cwd, false)
if err != nil {
return err
}
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
if err != nil {
return err
}
queries := db.New(conn)
// Create session with all original metadata
parentSessionID := sql.NullString{Valid: false}
if parentID != "" {
parentSessionID = sql.NullString{String: parentID, Valid: true}
}
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
ID: newSessionID,
ParentSessionID: parentSessionID,
Title: sess.Title,
MessageCount: sess.MessageCount,
PromptTokens: sess.PromptTokens,
CompletionTokens: sess.CompletionTokens,
Cost: sess.Cost,
})
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
}
result.ImportedSessions++
// Import children recursively
for _, child := range sess.Children {
err := importImportSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
}
}
return nil
}
func importSessionWithChildren(ctx context.Context, sessionService session.Service, messageService message.Service, sess SessionWithChildren, parentID string, result *ImportResult) error {
// Generate new session ID
newSessionID := uuid.New().String()
result.SessionMapping[sess.ID] = newSessionID
// Create the session using the low-level database API to preserve metadata
cwd, err := getCwd()
if err != nil {
return err
}
cfg, err := config.Init(cwd, false)
if err != nil {
return err
}
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
if err != nil {
return err
}
queries := db.New(conn)
// Create session with all original metadata
parentSessionID := sql.NullString{Valid: false}
if parentID != "" {
parentSessionID = sql.NullString{String: parentID, Valid: true}
}
_, err = queries.CreateSession(ctx, db.CreateSessionParams{
ID: newSessionID,
ParentSessionID: parentSessionID,
Title: sess.Title,
MessageCount: sess.MessageCount,
PromptTokens: sess.PromptTokens,
CompletionTokens: sess.CompletionTokens,
Cost: sess.Cost,
})
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
}
result.ImportedSessions++
// Import children recursively
for _, child := range sess.Children {
err := importSessionWithChildren(ctx, sessionService, messageService, child, newSessionID, result)
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("failed to import child session %s: %v", child.ID, err))
}
}
return nil
}
func runSessionsSearch(ctx context.Context, titlePattern, textPattern, format string) error {
sessionService, err := createSessionService(ctx)
if err != nil {
return err
}
var sessions []session.Session
// Determine which search method to use based on provided patterns
if titlePattern != "" && textPattern != "" {
sessions, err = sessionService.SearchByTitleAndText(ctx, titlePattern, textPattern)
} else if titlePattern != "" {
sessions, err = sessionService.SearchByTitle(ctx, titlePattern)
} else if textPattern != "" {
sessions, err = sessionService.SearchByText(ctx, textPattern)
}
if err != nil {
return fmt.Errorf("search failed: %w", err)
}
return formatSearchResults(sessions, format)
}
func formatSearchResults(sessions []session.Session, format string) error {
switch strings.ToLower(format) {
case "json":
return formatSearchResultsJSON(sessions)
case "text":
return formatSearchResultsText(sessions)
default:
return fmt.Errorf("unsupported format: %s", format)
}
}
func formatSearchResultsJSON(sessions []session.Session) error {
data, err := json.MarshalIndent(sessions, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
}
fmt.Println(string(data))
return nil
}
func formatSearchResultsText(sessions []session.Session) error {
if len(sessions) == 0 {
fmt.Println("No sessions found matching the search criteria.")
return nil
}
fmt.Printf("Found %d session(s):\n\n", len(sessions))
for _, sess := range sessions {
fmt.Printf("• %s (ID: %s)\n", sess.Title, sess.ID)
fmt.Printf(" Messages: %d, Cost: $%.4f\n", sess.MessageCount, sess.Cost)
fmt.Printf(" Created: %s\n", formatTimestamp(sess.CreatedAt))
if sess.ParentSessionID != "" {
fmt.Printf(" Parent: %s\n", sess.ParentSessionID)
}
fmt.Println()
}
return nil
}
func runSessionsStats(ctx context.Context, format, groupBy string) error {
// Get database connection
cwd, err := getCwd()
if err != nil {
return err
}
cfg, err := config.Init(cwd, false)
if err != nil {
return err
}
conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
if err != nil {
return err
}
queries := db.New(conn)
// Handle grouped statistics
if groupBy != "" {
return runGroupedStats(ctx, queries, format, groupBy)
}
// Get overall statistics
statsRow, err := queries.GetSessionStats(ctx)
if err != nil {
return fmt.Errorf("failed to get session stats: %w", err)
}
// Convert to our struct, handling NULL values
stats := SessionStats{
TotalSessions: statsRow.TotalSessions,
TotalMessages: convertNullFloat64ToInt64(statsRow.TotalMessages),
TotalPromptTokens: convertNullFloat64ToInt64(statsRow.TotalPromptTokens),
TotalCompletionTokens: convertNullFloat64ToInt64(statsRow.TotalCompletionTokens),
TotalCost: convertNullFloat64(statsRow.TotalCost),
AvgCostPerSession: convertNullFloat64(statsRow.AvgCostPerSession),
}
return formatStats(stats, format)
}
func runGroupedStats(ctx context.Context, queries *db.Queries, format, groupBy string) error {
var groupedStats []GroupedSessionStats
switch strings.ToLower(groupBy) {
case "day":
rows, err := queries.GetSessionStatsByDay(ctx)
if err != nil {
return fmt.Errorf("failed to get daily stats: %w", err)
}
groupedStats = convertDayStatsRows(rows)
case "week":
rows, err := queries.GetSessionStatsByWeek(ctx)
if err != nil {
return fmt.Errorf("failed to get weekly stats: %w", err)
}
groupedStats = convertWeekStatsRows(rows)
case "month":
rows, err := queries.GetSessionStatsByMonth(ctx)
if err != nil {
return fmt.Errorf("failed to get monthly stats: %w", err)
}
groupedStats = convertMonthStatsRows(rows)
default:
return fmt.Errorf("unsupported group-by value: %s. Valid values are: day, week, month", groupBy)
}
return formatGroupedStats(groupedStats, format, groupBy)
}
func convertNullFloat64(val sql.NullFloat64) float64 {
if val.Valid {
return val.Float64
}
return 0.0
}
func convertNullFloat64ToInt64(val sql.NullFloat64) int64 {
if val.Valid {
return int64(val.Float64)
}
return 0
}
func convertDayStatsRows(rows []db.GetSessionStatsByDayRow) []GroupedSessionStats {
result := make([]GroupedSessionStats, 0, len(rows))
for _, row := range rows {
stats := GroupedSessionStats{
Period: fmt.Sprintf("%v", row.Day),
SessionCount: row.SessionCount,
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
TotalCost: convertNullFloat64(row.TotalCost),
AvgCost: convertNullFloat64(row.AvgCost),
}
result = append(result, stats)
}
return result
}
func convertWeekStatsRows(rows []db.GetSessionStatsByWeekRow) []GroupedSessionStats {
result := make([]GroupedSessionStats, 0, len(rows))
for _, row := range rows {
stats := GroupedSessionStats{
Period: fmt.Sprintf("%v", row.WeekStart),
SessionCount: row.SessionCount,
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
TotalCost: convertNullFloat64(row.TotalCost),
AvgCost: convertNullFloat64(row.AvgCost),
}
result = append(result, stats)
}
return result
}
func convertMonthStatsRows(rows []db.GetSessionStatsByMonthRow) []GroupedSessionStats {
result := make([]GroupedSessionStats, 0, len(rows))
for _, row := range rows {
stats := GroupedSessionStats{
Period: fmt.Sprintf("%v", row.Month),
SessionCount: row.SessionCount,
MessageCount: convertNullFloat64ToInt64(row.MessageCount),
PromptTokens: convertNullFloat64ToInt64(row.PromptTokens),
CompletionTokens: convertNullFloat64ToInt64(row.CompletionTokens),
TotalCost: convertNullFloat64(row.TotalCost),
AvgCost: convertNullFloat64(row.AvgCost),
}
result = append(result, stats)
}
return result
}
func formatStats(stats SessionStats, format string) error {
switch strings.ToLower(format) {
case "json":
return formatStatsJSON(stats)
case "text":
return formatStatsText(stats)
default:
return fmt.Errorf("unsupported format: %s", format)
}
}
func formatGroupedStats(stats []GroupedSessionStats, format, groupBy string) error {
switch strings.ToLower(format) {
case "json":
return formatGroupedStatsJSON(stats)
case "text":
return formatGroupedStatsText(stats, groupBy)
default:
return fmt.Errorf("unsupported format: %s", format)
}
}
func formatStatsJSON(stats SessionStats) error {
data, err := json.MarshalIndent(stats, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
}
fmt.Println(string(data))
return nil
}
func formatStatsText(stats SessionStats) error {
if stats.TotalSessions == 0 {
fmt.Println("No sessions found.")
return nil
}
fmt.Println("Session Statistics")
fmt.Println("==================")
fmt.Printf("Total Sessions: %d\n", stats.TotalSessions)
fmt.Printf("Total Messages: %d\n", stats.TotalMessages)
fmt.Printf("Total Prompt Tokens: %d\n", stats.TotalPromptTokens)
fmt.Printf("Total Completion Tokens: %d\n", stats.TotalCompletionTokens)
fmt.Printf("Total Cost: $%.4f\n", stats.TotalCost)
fmt.Printf("Average Cost/Session: $%.4f\n", stats.AvgCostPerSession)
totalTokens := stats.TotalPromptTokens + stats.TotalCompletionTokens
if totalTokens > 0 {
fmt.Printf("Total Tokens: %d\n", totalTokens)
fmt.Printf("Average Tokens/Session: %.1f\n", float64(totalTokens)/float64(stats.TotalSessions))
}
if stats.TotalSessions > 0 {
fmt.Printf("Average Messages/Session: %.1f\n", float64(stats.TotalMessages)/float64(stats.TotalSessions))
}
return nil
}
func formatGroupedStatsJSON(stats []GroupedSessionStats) error {
data, err := json.MarshalIndent(stats, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
}
fmt.Println(string(data))
return nil
}
func formatGroupedStatsText(stats []GroupedSessionStats, groupBy string) error {
if len(stats) == 0 {
fmt.Printf("No sessions found for grouping by %s.\n", groupBy)
return nil
}
fmt.Printf("Session Statistics (Grouped by %s)\n", strings.ToUpper(groupBy[:1])+groupBy[1:])
fmt.Println(strings.Repeat("=", 30+len(groupBy)))
fmt.Println()
for _, stat := range stats {
fmt.Printf("Period: %s\n", stat.Period)
fmt.Printf(" Sessions: %d\n", stat.SessionCount)
fmt.Printf(" Messages: %d\n", stat.MessageCount)
fmt.Printf(" Prompt Tokens: %d\n", stat.PromptTokens)
fmt.Printf(" Completion Tokens: %d\n", stat.CompletionTokens)
fmt.Printf(" Total Cost: $%.4f\n", stat.TotalCost)
fmt.Printf(" Average Cost: $%.4f\n", stat.AvgCost)
totalTokens := stat.PromptTokens + stat.CompletionTokens
if totalTokens > 0 {
fmt.Printf(" Total Tokens: %d\n", totalTokens)
}
fmt.Println()
}
return nil
}

View File

@@ -60,6 +60,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
}
if q.getSessionStatsStmt, err = db.PrepareContext(ctx, getSessionStats); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionStats: %w", err)
}
if q.getSessionStatsByDayStmt, err = db.PrepareContext(ctx, getSessionStatsByDay); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionStatsByDay: %w", err)
}
if q.getSessionStatsByMonthStmt, err = db.PrepareContext(ctx, getSessionStatsByMonth); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionStatsByMonth: %w", err)
}
if q.getSessionStatsByWeekStmt, err = db.PrepareContext(ctx, getSessionStatsByWeek); err != nil {
return nil, fmt.Errorf("error preparing query GetSessionStatsByWeek: %w", err)
}
if q.listAllSessionsStmt, err = db.PrepareContext(ctx, listAllSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListAllSessions: %w", err)
}
@@ -84,6 +96,15 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
}
if q.searchSessionsByTextStmt, err = db.PrepareContext(ctx, searchSessionsByText); err != nil {
return nil, fmt.Errorf("error preparing query SearchSessionsByText: %w", err)
}
if q.searchSessionsByTitleStmt, err = db.PrepareContext(ctx, searchSessionsByTitle); err != nil {
return nil, fmt.Errorf("error preparing query SearchSessionsByTitle: %w", err)
}
if q.searchSessionsByTitleAndTextStmt, err = db.PrepareContext(ctx, searchSessionsByTitleAndText); err != nil {
return nil, fmt.Errorf("error preparing query SearchSessionsByTitleAndText: %w", err)
}
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
}
@@ -155,6 +176,26 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
}
}
if q.getSessionStatsStmt != nil {
if cerr := q.getSessionStatsStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getSessionStatsStmt: %w", cerr)
}
}
if q.getSessionStatsByDayStmt != nil {
if cerr := q.getSessionStatsByDayStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getSessionStatsByDayStmt: %w", cerr)
}
}
if q.getSessionStatsByMonthStmt != nil {
if cerr := q.getSessionStatsByMonthStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getSessionStatsByMonthStmt: %w", cerr)
}
}
if q.getSessionStatsByWeekStmt != nil {
if cerr := q.getSessionStatsByWeekStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getSessionStatsByWeekStmt: %w", cerr)
}
}
if q.listAllSessionsStmt != nil {
if cerr := q.listAllSessionsStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing listAllSessionsStmt: %w", cerr)
@@ -195,6 +236,21 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
}
}
if q.searchSessionsByTextStmt != nil {
if cerr := q.searchSessionsByTextStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing searchSessionsByTextStmt: %w", cerr)
}
}
if q.searchSessionsByTitleStmt != nil {
if cerr := q.searchSessionsByTitleStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing searchSessionsByTitleStmt: %w", cerr)
}
}
if q.searchSessionsByTitleAndTextStmt != nil {
if cerr := q.searchSessionsByTitleAndTextStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing searchSessionsByTitleAndTextStmt: %w", cerr)
}
}
if q.updateMessageStmt != nil {
if cerr := q.updateMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
@@ -242,57 +298,71 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
}
type Queries struct {
db DBTX
tx *sql.Tx
createFileStmt *sql.Stmt
createMessageStmt *sql.Stmt
createSessionStmt *sql.Stmt
deleteFileStmt *sql.Stmt
deleteMessageStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
deleteSessionFilesStmt *sql.Stmt
deleteSessionMessagesStmt *sql.Stmt
getFileStmt *sql.Stmt
getFileByPathAndSessionStmt *sql.Stmt
getMessageStmt *sql.Stmt
getSessionByIDStmt *sql.Stmt
listAllSessionsStmt *sql.Stmt
listChildSessionsStmt *sql.Stmt
listFilesByPathStmt *sql.Stmt
listFilesBySessionStmt *sql.Stmt
listLatestSessionFilesStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
db DBTX
tx *sql.Tx
createFileStmt *sql.Stmt
createMessageStmt *sql.Stmt
createSessionStmt *sql.Stmt
deleteFileStmt *sql.Stmt
deleteMessageStmt *sql.Stmt
deleteSessionStmt *sql.Stmt
deleteSessionFilesStmt *sql.Stmt
deleteSessionMessagesStmt *sql.Stmt
getFileStmt *sql.Stmt
getFileByPathAndSessionStmt *sql.Stmt
getMessageStmt *sql.Stmt
getSessionByIDStmt *sql.Stmt
getSessionStatsStmt *sql.Stmt
getSessionStatsByDayStmt *sql.Stmt
getSessionStatsByMonthStmt *sql.Stmt
getSessionStatsByWeekStmt *sql.Stmt
listAllSessionsStmt *sql.Stmt
listChildSessionsStmt *sql.Stmt
listFilesByPathStmt *sql.Stmt
listFilesBySessionStmt *sql.Stmt
listLatestSessionFilesStmt *sql.Stmt
listMessagesBySessionStmt *sql.Stmt
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
searchSessionsByTextStmt *sql.Stmt
searchSessionsByTitleStmt *sql.Stmt
searchSessionsByTitleAndTextStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
tx: tx,
createFileStmt: q.createFileStmt,
createMessageStmt: q.createMessageStmt,
createSessionStmt: q.createSessionStmt,
deleteFileStmt: q.deleteFileStmt,
deleteMessageStmt: q.deleteMessageStmt,
deleteSessionStmt: q.deleteSessionStmt,
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
getFileStmt: q.getFileStmt,
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
getMessageStmt: q.getMessageStmt,
getSessionByIDStmt: q.getSessionByIDStmt,
listAllSessionsStmt: q.listAllSessionsStmt,
listChildSessionsStmt: q.listChildSessionsStmt,
listFilesByPathStmt: q.listFilesByPathStmt,
listFilesBySessionStmt: q.listFilesBySessionStmt,
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
db: tx,
tx: tx,
createFileStmt: q.createFileStmt,
createMessageStmt: q.createMessageStmt,
createSessionStmt: q.createSessionStmt,
deleteFileStmt: q.deleteFileStmt,
deleteMessageStmt: q.deleteMessageStmt,
deleteSessionStmt: q.deleteSessionStmt,
deleteSessionFilesStmt: q.deleteSessionFilesStmt,
deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
getFileStmt: q.getFileStmt,
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
getMessageStmt: q.getMessageStmt,
getSessionByIDStmt: q.getSessionByIDStmt,
getSessionStatsStmt: q.getSessionStatsStmt,
getSessionStatsByDayStmt: q.getSessionStatsByDayStmt,
getSessionStatsByMonthStmt: q.getSessionStatsByMonthStmt,
getSessionStatsByWeekStmt: q.getSessionStatsByWeekStmt,
listAllSessionsStmt: q.listAllSessionsStmt,
listChildSessionsStmt: q.listChildSessionsStmt,
listFilesByPathStmt: q.listFilesByPathStmt,
listFilesBySessionStmt: q.listFilesBySessionStmt,
listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
listMessagesBySessionStmt: q.listMessagesBySessionStmt,
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
searchSessionsByTextStmt: q.searchSessionsByTextStmt,
searchSessionsByTitleStmt: q.searchSessionsByTitleStmt,
searchSessionsByTitleAndTextStmt: q.searchSessionsByTitleAndTextStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
}
}

View File

@@ -22,6 +22,10 @@ type Querier interface {
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
GetMessage(ctx context.Context, id string) (Message, error)
GetSessionByID(ctx context.Context, id string) (Session, error)
GetSessionStats(ctx context.Context) (GetSessionStatsRow, error)
GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error)
GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error)
GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error)
ListAllSessions(ctx context.Context) ([]Session, error)
ListChildSessions(ctx context.Context, parentSessionID sql.NullString) ([]Session, error)
ListFilesByPath(ctx context.Context, path string) ([]File, error)
@@ -30,6 +34,9 @@ type Querier interface {
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
ListNewFiles(ctx context.Context) ([]File, error)
ListSessions(ctx context.Context) ([]Session, error)
SearchSessionsByText(ctx context.Context, parts string) ([]Session, error)
SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error)
SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
}

View File

@@ -106,6 +106,205 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
return i, err
}
const getSessionStats = `-- name: GetSessionStats :one
SELECT
COUNT(*) as total_sessions,
SUM(message_count) as total_messages,
SUM(prompt_tokens) as total_prompt_tokens,
SUM(completion_tokens) as total_completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost_per_session
FROM sessions
`
type GetSessionStatsRow struct {
TotalSessions int64 `json:"total_sessions"`
TotalMessages sql.NullFloat64 `json:"total_messages"`
TotalPromptTokens sql.NullFloat64 `json:"total_prompt_tokens"`
TotalCompletionTokens sql.NullFloat64 `json:"total_completion_tokens"`
TotalCost sql.NullFloat64 `json:"total_cost"`
AvgCostPerSession sql.NullFloat64 `json:"avg_cost_per_session"`
}
func (q *Queries) GetSessionStats(ctx context.Context) (GetSessionStatsRow, error) {
row := q.queryRow(ctx, q.getSessionStatsStmt, getSessionStats)
var i GetSessionStatsRow
err := row.Scan(
&i.TotalSessions,
&i.TotalMessages,
&i.TotalPromptTokens,
&i.TotalCompletionTokens,
&i.TotalCost,
&i.AvgCostPerSession,
)
return i, err
}
const getSessionStatsByDay = `-- name: GetSessionStatsByDay :many
SELECT
date(created_at, 'unixepoch') as day,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY date(created_at, 'unixepoch')
ORDER BY day DESC
`
type GetSessionStatsByDayRow struct {
Day interface{} `json:"day"`
SessionCount int64 `json:"session_count"`
MessageCount sql.NullFloat64 `json:"message_count"`
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
TotalCost sql.NullFloat64 `json:"total_cost"`
AvgCost sql.NullFloat64 `json:"avg_cost"`
}
func (q *Queries) GetSessionStatsByDay(ctx context.Context) ([]GetSessionStatsByDayRow, error) {
rows, err := q.query(ctx, q.getSessionStatsByDayStmt, getSessionStatsByDay)
if err != nil {
return nil, err
}
defer rows.Close()
items := []GetSessionStatsByDayRow{}
for rows.Next() {
var i GetSessionStatsByDayRow
if err := rows.Scan(
&i.Day,
&i.SessionCount,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.TotalCost,
&i.AvgCost,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSessionStatsByMonth = `-- name: GetSessionStatsByMonth :many
SELECT
strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
ORDER BY month DESC
`
type GetSessionStatsByMonthRow struct {
Month interface{} `json:"month"`
SessionCount int64 `json:"session_count"`
MessageCount sql.NullFloat64 `json:"message_count"`
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
TotalCost sql.NullFloat64 `json:"total_cost"`
AvgCost sql.NullFloat64 `json:"avg_cost"`
}
func (q *Queries) GetSessionStatsByMonth(ctx context.Context) ([]GetSessionStatsByMonthRow, error) {
rows, err := q.query(ctx, q.getSessionStatsByMonthStmt, getSessionStatsByMonth)
if err != nil {
return nil, err
}
defer rows.Close()
items := []GetSessionStatsByMonthRow{}
for rows.Next() {
var i GetSessionStatsByMonthRow
if err := rows.Scan(
&i.Month,
&i.SessionCount,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.TotalCost,
&i.AvgCost,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSessionStatsByWeek = `-- name: GetSessionStatsByWeek :many
SELECT
date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
ORDER BY week_start DESC
`
type GetSessionStatsByWeekRow struct {
WeekStart interface{} `json:"week_start"`
SessionCount int64 `json:"session_count"`
MessageCount sql.NullFloat64 `json:"message_count"`
PromptTokens sql.NullFloat64 `json:"prompt_tokens"`
CompletionTokens sql.NullFloat64 `json:"completion_tokens"`
TotalCost sql.NullFloat64 `json:"total_cost"`
AvgCost sql.NullFloat64 `json:"avg_cost"`
}
func (q *Queries) GetSessionStatsByWeek(ctx context.Context) ([]GetSessionStatsByWeekRow, error) {
rows, err := q.query(ctx, q.getSessionStatsByWeekStmt, getSessionStatsByWeek)
if err != nil {
return nil, err
}
defer rows.Close()
items := []GetSessionStatsByWeekRow{}
for rows.Next() {
var i GetSessionStatsByWeekRow
if err := rows.Scan(
&i.WeekStart,
&i.SessionCount,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.TotalCost,
&i.AvgCost,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAllSessions = `-- name: ListAllSessions :many
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
FROM sessions
@@ -228,6 +427,136 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
return items, nil
}
const searchSessionsByText = `-- name: SearchSessionsByText :many
SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
FROM sessions s
JOIN messages m ON s.id = m.session_id
WHERE m.parts LIKE ?
ORDER BY s.created_at DESC
`
func (q *Queries) SearchSessionsByText(ctx context.Context, parts string) ([]Session, error) {
rows, err := q.query(ctx, q.searchSessionsByTextStmt, searchSessionsByText, parts)
if err != nil {
return nil, err
}
defer rows.Close()
items := []Session{}
for rows.Next() {
var i Session
if err := rows.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const searchSessionsByTitle = `-- name: SearchSessionsByTitle :many
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
FROM sessions
WHERE title LIKE ?
ORDER BY created_at DESC
`
func (q *Queries) SearchSessionsByTitle(ctx context.Context, title string) ([]Session, error) {
rows, err := q.query(ctx, q.searchSessionsByTitleStmt, searchSessionsByTitle, title)
if err != nil {
return nil, err
}
defer rows.Close()
items := []Session{}
for rows.Next() {
var i Session
if err := rows.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const searchSessionsByTitleAndText = `-- name: SearchSessionsByTitleAndText :many
SELECT DISTINCT s.id, s.parent_session_id, s.title, s.message_count, s.prompt_tokens, s.completion_tokens, s.cost, s.updated_at, s.created_at, s.summary_message_id
FROM sessions s
JOIN messages m ON s.id = m.session_id
WHERE s.title LIKE ? AND m.parts LIKE ?
ORDER BY s.created_at DESC
`
type SearchSessionsByTitleAndTextParams struct {
Title string `json:"title"`
Parts string `json:"parts"`
}
func (q *Queries) SearchSessionsByTitleAndText(ctx context.Context, arg SearchSessionsByTitleAndTextParams) ([]Session, error) {
rows, err := q.query(ctx, q.searchSessionsByTitleAndTextStmt, searchSessionsByTitleAndText, arg.Title, arg.Parts)
if err != nil {
return nil, err
}
defer rows.Close()
items := []Session{}
for rows.Next() {
var i Session
if err := rows.Scan(
&i.ID,
&i.ParentSessionID,
&i.Title,
&i.MessageCount,
&i.PromptTokens,
&i.CompletionTokens,
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateSession = `-- name: UpdateSession :one
UPDATE sessions
SET

View File

@@ -60,3 +60,72 @@ ORDER BY created_at ASC;
SELECT *
FROM sessions
ORDER BY created_at DESC;
-- name: SearchSessionsByTitle :many
SELECT *
FROM sessions
WHERE title LIKE ?
ORDER BY created_at DESC;
-- name: SearchSessionsByTitleAndText :many
SELECT DISTINCT s.*
FROM sessions s
JOIN messages m ON s.id = m.session_id
WHERE s.title LIKE ? AND m.parts LIKE ?
ORDER BY s.created_at DESC;
-- name: SearchSessionsByText :many
SELECT DISTINCT s.*
FROM sessions s
JOIN messages m ON s.id = m.session_id
WHERE m.parts LIKE ?
ORDER BY s.created_at DESC;
-- name: GetSessionStats :one
SELECT
COUNT(*) as total_sessions,
SUM(message_count) as total_messages,
SUM(prompt_tokens) as total_prompt_tokens,
SUM(completion_tokens) as total_completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost_per_session
FROM sessions;
-- name: GetSessionStatsByDay :many
SELECT
date(created_at, 'unixepoch') as day,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY date(created_at, 'unixepoch')
ORDER BY day DESC;
-- name: GetSessionStatsByWeek :many
SELECT
date(created_at, 'unixepoch', 'weekday 0', '-6 days') as week_start,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY date(created_at, 'unixepoch', 'weekday 0', '-6 days')
ORDER BY week_start DESC;
-- name: GetSessionStatsByMonth :many
SELECT
strftime('%Y-%m', datetime(created_at, 'unixepoch')) as month,
COUNT(*) as session_count,
SUM(message_count) as message_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(cost) as total_cost,
AVG(cost) as avg_cost
FROM sessions
GROUP BY strftime('%Y-%m', datetime(created_at, 'unixepoch'))
ORDER BY month DESC;

View File

@@ -33,6 +33,9 @@ type Service interface {
ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
Delete(ctx context.Context, id string) error
SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error)
SearchByText(ctx context.Context, textPattern string) ([]Session, error)
SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error)
}
type service struct {
@@ -161,6 +164,45 @@ func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]S
return sessions, nil
}
func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) {
dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%")
if err != nil {
return nil, err
}
sessions := make([]Session, len(dbSessions))
for i, dbSession := range dbSessions {
sessions[i] = s.fromDBItem(dbSession)
}
return sessions, nil
}
func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) {
dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%")
if err != nil {
return nil, err
}
sessions := make([]Session, len(dbSessions))
for i, dbSession := range dbSessions {
sessions[i] = s.fromDBItem(dbSession)
}
return sessions, nil
}
func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) {
dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{
Title: "%" + titlePattern + "%",
Parts: "%" + textPattern + "%",
})
if err != nil {
return nil, err
}
sessions := make([]Session, len(dbSessions))
for i, dbSession := range dbSessions {
sessions[i] = s.fromDBItem(dbSession)
}
return sessions, nil
}
func (s service) fromDBItem(item db.Session) Session {
return Session{
ID: item.ID,

View File

@@ -0,0 +1,297 @@
# RFC: Session Import and Export
## Summary
This RFC proposes a comprehensive system for importing and exporting conversation sessions in Crush.
## Background
Crush manages conversations through a hierarchical session system where:
- Sessions contain metadata (title, token counts, cost, timestamps)
- Sessions can have parent-child relationships (nested conversations)
- Messages within sessions have structured content parts (text, tool calls, reasoning, etc.)
- The current implementation provides export functionality but lacks import capabilities
The latest commit introduced three key commands:
- `crush sessions list` - List sessions in various formats
- `crush sessions export` - Export all sessions and metadata
- `crush sessions export-conversation <session-id>` - Export a single conversation with messages
## Motivation
Users need to:
1. Share conversations with others
2. Use conversation logs for debugging
3. Archive and analyze conversation history
4. Export data for external tools
## Detailed Design
### Core Data Model
The session export format builds on the existing session structure:
```go
type Session struct {
ID string `json:"id"`
ParentSessionID string `json:"parent_session_id,omitempty"`
Title string `json:"title"`
MessageCount int64 `json:"message_count"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
SummaryMessageID string `json:"summary_message_id,omitempty"`
}
type SessionWithChildren struct {
Session
Children []SessionWithChildren `json:"children,omitempty"`
}
```
### Proposed Command Interface
#### Export Commands (Already Implemented)
```bash
# List sessions in various formats
crush sessions list [--format text|json|yaml|markdown]
# Export all sessions with metadata
crush sessions export [--format json|yaml|markdown]
# Export single conversation with full message history
crush sessions export-conversation <session-id> [--format markdown|json|yaml]
```
#### New Import Commands
```bash
# Import sessions from a file
crush sessions import <file> [--format json|yaml] [--dry-run]
# Import a single conversation
crush sessions import-conversation <file> [--format json|yaml|markdown]
```
#### Enhanced Inspection Commands
```bash
# Search sessions by criteria
crush sessions search [--title <pattern>] [--text <text>] [--format text|json]
# Show session statistics
crush sessions stats [--format text|json] [--group-by day|week|month]
# Show statistics for a single session
crush sessions stats <session-id> [--format text|json]
```
### Import/Export Formats
#### Full Export Format (JSON)
```json
{
"version": "1.0",
"exported_at": "2025-01-27T10:30:00Z",
"total_sessions": 15,
"sessions": [
{
"id": "session-123",
"parent_session_id": "",
"title": "API Design Discussion",
"message_count": 8,
"prompt_tokens": 1250,
"completion_tokens": 890,
"cost": 0.0234,
"created_at": 1706356200,
"updated_at": 1706359800,
"children": [
{
"id": "session-124",
"parent_session_id": "session-123",
"title": "Implementation Details",
"message_count": 4,
"prompt_tokens": 650,
"completion_tokens": 420,
"cost": 0.0145,
"created_at": 1706359900,
"updated_at": 1706361200
}
]
}
]
}
```
#### Conversation Export Format (JSON)
```json
{
"version": "1.0",
"session": {
"id": "session-123",
"title": "API Design Discussion",
"created_at": 1706356200,
"message_count": 3
},
"messages": [
{
"id": "msg-001",
"session_id": "session-123",
"role": "user",
"parts": [
{
"type": "text",
"data": {
"text": "Help me design a REST API for user management"
}
}
],
"created_at": 1706356200
},
{
"id": "msg-002",
"session_id": "session-123",
"role": "assistant",
"model": "gpt-4",
"provider": "openai",
"parts": [
{
"type": "text",
"data": {
"text": "I'll help you design a REST API for user management..."
}
},
{
"type": "finish",
"data": {
"reason": "stop",
"time": 1706356230
}
}
],
"created_at": 1706356220
}
]
}
```
### API Implementation
#### Import Service Interface
```go
type ImportService interface {
// Import sessions from structured data
ImportSessions(ctx context.Context, data ImportData, opts ImportOptions) (ImportResult, error)
// Import single conversation
ImportConversation(ctx context.Context, data ConversationData, opts ImportOptions) (Session, error)
// Validate import data without persisting
ValidateImport(ctx context.Context, data ImportData) (ValidationResult, error)
}
type ImportOptions struct {
ConflictStrategy ConflictStrategy // skip, merge, replace
DryRun bool
ParentSessionID string // For conversation imports
PreserveIDs bool // Whether to preserve original IDs
}
type ConflictStrategy string
const (
ConflictSkip ConflictStrategy = "skip" // Skip existing sessions
ConflictMerge ConflictStrategy = "merge" // Merge with existing
ConflictReplace ConflictStrategy = "replace" // Replace existing
)
type ImportResult struct {
TotalSessions int `json:"total_sessions"`
ImportedSessions int `json:"imported_sessions"`
SkippedSessions int `json:"skipped_sessions"`
Errors []ImportError `json:"errors,omitempty"`
SessionMapping map[string]string `json:"session_mapping"` // old_id -> new_id
}
```
#### Enhanced Export Service
```go
type ExportService interface {
// Export sessions with filtering
ExportSessions(ctx context.Context, opts ExportOptions) ([]SessionWithChildren, error)
// Export conversation with full message history
ExportConversation(ctx context.Context, sessionID string, opts ExportOptions) (ConversationExport, error)
// Search and filter sessions
SearchSessions(ctx context.Context, criteria SearchCriteria) ([]Session, error)
// Get session statistics
GetStats(ctx context.Context, opts StatsOptions) (SessionStats, error)
}
type ExportOptions struct {
Format string // json, yaml, markdown
IncludeMessages bool // Include full message content
DateRange DateRange // Filter by date range
SessionIDs []string // Export specific sessions
}
type SearchCriteria struct {
TitlePattern string
DateRange DateRange
MinCost float64
MaxCost float64
ParentSessionID string
HasChildren *bool
}
```
## Implementation Status
The proposed session import/export functionality has been implemented as a prototype as of July 2025.
### Implemented Commands
All new commands have been added to `internal/cmd/sessions.go`:
- **Import**: `crush sessions import <file> [--format json|yaml] [--dry-run]`
- Supports hierarchical session imports with parent-child relationships
- Generates new UUIDs to avoid conflicts
- Includes validation and dry-run capabilities
- **Import Conversation**: `crush sessions import-conversation <file> [--format json|yaml]`
- Imports single conversations with full message history
- Preserves all message content parts and metadata
- **Search**: `crush sessions search [--title <pattern>] [--text <text>] [--format text|json]`
- Case-insensitive title search and message content search
- Supports combined search criteria with AND logic
- **Stats**: `crush sessions stats [--format text|json] [--group-by day|week|month]`
- Comprehensive usage statistics (sessions, messages, tokens, costs)
- Time-based grouping with efficient database queries
### Database Changes
Added new SQL queries in `internal/db/sql/sessions.sql`:
- Search queries for title and message content filtering
- Statistics aggregation queries with time-based grouping
- All queries optimized for performance with proper indexing
### Database Schema Considerations
The current schema supports the import/export functionality. Additional indexes may be needed for search performance:
```sql
-- Optimize session searches by date and cost
CREATE INDEX idx_sessions_created_at ON sessions(created_at);
CREATE INDEX idx_sessions_cost ON sessions(cost);
CREATE INDEX idx_sessions_title ON sessions(title COLLATE NOCASE);
-- Optimize message searches by session
CREATE INDEX idx_messages_session_created ON messages(session_id, created_at);
```