mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Improve summary to keep context (#159)
* improve summary to keep context * improve loop * remove debug msg
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: files.sql
|
||||
|
||||
package db
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: messages.sql
|
||||
|
||||
package db
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE sessions ADD COLUMN summary_message_id TEXT;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE sessions DROP COLUMN summary_message_id;
|
||||
-- +goose StatementEnd
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
@@ -39,4 +39,5 @@ type Session struct {
|
||||
Cost float64 `json:"cost"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
SummaryMessageID sql.NullString `json:"summary_message_id"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.27.0
|
||||
// sqlc v1.29.0
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
@@ -19,6 +19,7 @@ INSERT INTO sessions (
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cost,
|
||||
summary_message_id,
|
||||
updated_at,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -29,9 +30,10 @@ INSERT INTO sessions (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
null,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
@@ -65,6 +67,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -80,7 +83,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
|
||||
}
|
||||
|
||||
const getSessionByID = `-- name: GetSessionByID :one
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
FROM sessions
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
@@ -98,12 +101,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listSessions = `-- name: ListSessions :many
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
FROM sessions
|
||||
WHERE parent_session_id is NULL
|
||||
ORDER BY created_at DESC
|
||||
@@ -128,6 +132,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,17 +153,19 @@ SET
|
||||
title = ?,
|
||||
prompt_tokens = ?,
|
||||
completion_tokens = ?,
|
||||
summary_message_id = ?,
|
||||
cost = ?
|
||||
WHERE id = ?
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Title string `json:"title"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
SummaryMessageID sql.NullString `json:"summary_message_id"`
|
||||
Cost float64 `json:"cost"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
|
||||
@@ -166,6 +173,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
arg.Title,
|
||||
arg.PromptTokens,
|
||||
arg.CompletionTokens,
|
||||
arg.SummaryMessageID,
|
||||
arg.Cost,
|
||||
arg.ID,
|
||||
)
|
||||
@@ -180,6 +188,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
&i.Cost,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
&i.SummaryMessageID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ INSERT INTO sessions (
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cost,
|
||||
summary_message_id,
|
||||
updated_at,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -17,6 +18,7 @@ INSERT INTO sessions (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
null,
|
||||
strftime('%s', 'now'),
|
||||
strftime('%s', 'now')
|
||||
) RETURNING *;
|
||||
@@ -38,6 +40,7 @@ SET
|
||||
title = ?,
|
||||
prompt_tokens = ?,
|
||||
completion_tokens = ?,
|
||||
summary_message_id = ?,
|
||||
cost = ?
|
||||
WHERE id = ?
|
||||
RETURNING *;
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opencode-ai/opencode/internal/config"
|
||||
"github.com/opencode-ai/opencode/internal/llm/models"
|
||||
@@ -245,6 +246,23 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
||||
}
|
||||
}()
|
||||
}
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return a.err(fmt.Errorf("failed to get session: %w", err))
|
||||
}
|
||||
if session.SummaryMessageID != "" {
|
||||
summaryMsgInex := -1
|
||||
for i, msg := range msgs {
|
||||
if msg.ID == session.SummaryMessageID {
|
||||
summaryMsgInex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if summaryMsgInex != -1 {
|
||||
msgs = msgs[summaryMsgInex:]
|
||||
msgs[0].Role = message.User
|
||||
}
|
||||
}
|
||||
|
||||
userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
|
||||
if err != nil {
|
||||
@@ -614,22 +632,16 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
|
||||
a.Publish(pubsub.CreatedEvent, event)
|
||||
return
|
||||
}
|
||||
// Create a new session with the summary
|
||||
newSession, err := a.sessions.Create(summarizeCtx, oldSession.Title+" - Continuation")
|
||||
if err != nil {
|
||||
event = AgentEvent{
|
||||
Type: AgentEventTypeError,
|
||||
Error: fmt.Errorf("failed to create new session: %w", err),
|
||||
Done: true,
|
||||
}
|
||||
a.Publish(pubsub.CreatedEvent, event)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a message in the new session with the summary
|
||||
_, err = a.messages.Create(summarizeCtx, newSession.ID, message.CreateMessageParams{
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{message.TextContent{Text: summary}},
|
||||
msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{Text: summary},
|
||||
message.Finish{
|
||||
Reason: message.FinishReasonEndTurn,
|
||||
Time: time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
Model: a.summarizeProvider.Model().ID,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -642,9 +654,29 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
|
||||
a.Publish(pubsub.CreatedEvent, event)
|
||||
return
|
||||
}
|
||||
oldSession.SummaryMessageID = msg.ID
|
||||
oldSession.CompletionTokens = response.Usage.OutputTokens
|
||||
oldSession.PromptTokens = 0
|
||||
model := a.summarizeProvider.Model()
|
||||
usage := response.Usage
|
||||
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
||||
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
||||
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
||||
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
||||
oldSession.Cost += cost
|
||||
_, err = a.sessions.Save(summarizeCtx, oldSession)
|
||||
if err != nil {
|
||||
event = AgentEvent{
|
||||
Type: AgentEventTypeError,
|
||||
Error: fmt.Errorf("failed to save session: %w", err),
|
||||
Done: true,
|
||||
}
|
||||
a.Publish(pubsub.CreatedEvent, event)
|
||||
}
|
||||
|
||||
event = AgentEvent{
|
||||
Type: AgentEventTypeSummarize,
|
||||
SessionID: newSession.ID,
|
||||
SessionID: oldSession.ID,
|
||||
Progress: "Summary complete",
|
||||
Done: true,
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type Session struct {
|
||||
MessageCount int64
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
SummaryMessageID string
|
||||
Cost float64
|
||||
CreatedAt int64
|
||||
UpdatedAt int64
|
||||
@@ -105,7 +106,11 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
|
||||
Title: session.Title,
|
||||
PromptTokens: session.PromptTokens,
|
||||
CompletionTokens: session.CompletionTokens,
|
||||
Cost: session.Cost,
|
||||
SummaryMessageID: sql.NullString{
|
||||
String: session.SummaryMessageID,
|
||||
Valid: session.SummaryMessageID != "",
|
||||
},
|
||||
Cost: session.Cost,
|
||||
})
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
@@ -135,6 +140,7 @@ func (s service) fromDBItem(item db.Session) Session {
|
||||
MessageCount: item.MessageCount,
|
||||
PromptTokens: item.PromptTokens,
|
||||
CompletionTokens: item.CompletionTokens,
|
||||
SummaryMessageID: item.SummaryMessageID.String,
|
||||
Cost: item.Cost,
|
||||
CreatedAt: item.CreatedAt,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
|
||||
@@ -99,6 +99,14 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case renderFinishedMsg:
|
||||
m.rendering = false
|
||||
m.viewport.GotoBottom()
|
||||
case pubsub.Event[session.Session]:
|
||||
if msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.session.ID {
|
||||
m.session = msg.Payload
|
||||
if m.session.SummaryMessageID == m.currentMsgID {
|
||||
delete(m.cachedContent, m.currentMsgID)
|
||||
m.renderView()
|
||||
}
|
||||
}
|
||||
case pubsub.Event[message.Message]:
|
||||
needsRerender := false
|
||||
if msg.Type == pubsub.CreatedEvent {
|
||||
@@ -208,12 +216,15 @@ func (m *messagesCmp) renderView() {
|
||||
m.uiMessages = append(m.uiMessages, cache.content...)
|
||||
continue
|
||||
}
|
||||
isSummary := m.session.SummaryMessageID == msg.ID
|
||||
|
||||
assistantMessages := renderAssistantMessage(
|
||||
msg,
|
||||
inx,
|
||||
m.messages,
|
||||
m.app.Messages,
|
||||
m.currentMsgID,
|
||||
isSummary,
|
||||
m.width,
|
||||
pos,
|
||||
)
|
||||
|
||||
@@ -120,6 +120,7 @@ func renderAssistantMessage(
|
||||
allMessages []message.Message, // we need this to get tool results and the user message
|
||||
messagesService message.Service, // We need this to get the task tool messages
|
||||
focusedUIMessageId string,
|
||||
isSummary bool,
|
||||
width int,
|
||||
position int,
|
||||
) []uiMessage {
|
||||
@@ -168,6 +169,9 @@ func renderAssistantMessage(
|
||||
if content == "" {
|
||||
content = "*Finished without output*"
|
||||
}
|
||||
if isSummary {
|
||||
info = append(info, baseStyle.Width(width-1).Foreground(t.TextMuted()).Render(" (summary)"))
|
||||
}
|
||||
|
||||
content = renderMessage(content, false, true, width, info...)
|
||||
messages = append(messages, uiMessage{
|
||||
|
||||
@@ -331,30 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
if payload.Done && payload.Type == agent.AgentEventTypeSummarize {
|
||||
a.isCompacting = false
|
||||
|
||||
if payload.SessionID != "" {
|
||||
// Switch to the new session
|
||||
return a, func() tea.Msg {
|
||||
sessions, err := a.app.Sessions.List(context.Background())
|
||||
if err != nil {
|
||||
return util.InfoMsg{
|
||||
Type: util.InfoTypeError,
|
||||
Msg: "Failed to list sessions: " + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
if s.ID == payload.SessionID {
|
||||
return dialog.SessionSelectedMsg{Session: s}
|
||||
}
|
||||
}
|
||||
|
||||
return util.InfoMsg{
|
||||
Type: util.InfoTypeError,
|
||||
Msg: "Failed to find new session",
|
||||
}
|
||||
}
|
||||
}
|
||||
return a, util.ReportInfo("Session summarization complete")
|
||||
} else if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSession.ID != "" {
|
||||
model := a.app.CoderAgent.Model()
|
||||
|
||||
Reference in New Issue
Block a user