Improve summary to keep context (#159)

* improve summary to keep context

* improve loop

* remove debug msg
This commit is contained in:
Kujtim Hoxha
2025-05-15 15:59:18 +02:00
committed by GitHub
parent 4f0c1c633a
commit 3e424754b4
13 changed files with 107 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: files.sql
package db

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: messages.sql
package db

View File

@@ -0,0 +1,9 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE sessions ADD COLUMN summary_message_id TEXT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE sessions DROP COLUMN summary_message_id;
-- +goose StatementEnd

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db
@@ -39,4 +39,5 @@ type Session struct {
Cost float64 `json:"cost"`
UpdatedAt int64 `json:"updated_at"`
CreatedAt int64 `json:"created_at"`
SummaryMessageID sql.NullString `json:"summary_message_id"`
}

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
package db

View File

@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.27.0
// sqlc v1.29.0
// source: sessions.sql
package db
@@ -19,6 +19,7 @@ INSERT INTO sessions (
prompt_tokens,
completion_tokens,
cost,
summary_message_id,
updated_at,
created_at
) VALUES (
@@ -29,9 +30,10 @@ INSERT INTO sessions (
?,
?,
?,
null,
strftime('%s', 'now'),
strftime('%s', 'now')
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
`
type CreateSessionParams struct {
@@ -65,6 +67,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
)
return i, err
}
@@ -80,7 +83,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
}
const getSessionByID = `-- name: GetSessionByID :one
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
FROM sessions
WHERE id = ? LIMIT 1
`
@@ -98,12 +101,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
)
return i, err
}
const listSessions = `-- name: ListSessions :many
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
FROM sessions
WHERE parent_session_id is NULL
ORDER BY created_at DESC
@@ -128,6 +132,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
); err != nil {
return nil, err
}
@@ -148,17 +153,19 @@ SET
title = ?,
prompt_tokens = ?,
completion_tokens = ?,
summary_message_id = ?,
cost = ?
WHERE id = ?
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
`
type UpdateSessionParams struct {
Title string `json:"title"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
Cost float64 `json:"cost"`
ID string `json:"id"`
Title string `json:"title"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
SummaryMessageID sql.NullString `json:"summary_message_id"`
Cost float64 `json:"cost"`
ID string `json:"id"`
}
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
@@ -166,6 +173,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
arg.Title,
arg.PromptTokens,
arg.CompletionTokens,
arg.SummaryMessageID,
arg.Cost,
arg.ID,
)
@@ -180,6 +188,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.Cost,
&i.UpdatedAt,
&i.CreatedAt,
&i.SummaryMessageID,
)
return i, err
}

View File

@@ -7,6 +7,7 @@ INSERT INTO sessions (
prompt_tokens,
completion_tokens,
cost,
summary_message_id,
updated_at,
created_at
) VALUES (
@@ -17,6 +18,7 @@ INSERT INTO sessions (
?,
?,
?,
null,
strftime('%s', 'now'),
strftime('%s', 'now')
) RETURNING *;
@@ -38,6 +40,7 @@ SET
title = ?,
prompt_tokens = ?,
completion_tokens = ?,
summary_message_id = ?,
cost = ?
WHERE id = ?
RETURNING *;

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"sync"
"time"
"github.com/opencode-ai/opencode/internal/config"
"github.com/opencode-ai/opencode/internal/llm/models"
@@ -245,6 +246,23 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
}()
}
session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
return a.err(fmt.Errorf("failed to get session: %w", err))
}
if session.SummaryMessageID != "" {
summaryMsgInex := -1
for i, msg := range msgs {
if msg.ID == session.SummaryMessageID {
summaryMsgInex = i
break
}
}
if summaryMsgInex != -1 {
msgs = msgs[summaryMsgInex:]
msgs[0].Role = message.User
}
}
userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
if err != nil {
@@ -614,22 +632,16 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
return
}
// Create a new session with the summary
newSession, err := a.sessions.Create(summarizeCtx, oldSession.Title+" - Continuation")
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to create new session: %w", err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
// Create a message in the new session with the summary
_, err = a.messages.Create(summarizeCtx, newSession.ID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{message.TextContent{Text: summary}},
msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{
message.TextContent{Text: summary},
message.Finish{
Reason: message.FinishReasonEndTurn,
Time: time.Now().Unix(),
},
},
Model: a.summarizeProvider.Model().ID,
})
if err != nil {
@@ -642,9 +654,29 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
return
}
oldSession.SummaryMessageID = msg.ID
oldSession.CompletionTokens = response.Usage.OutputTokens
oldSession.PromptTokens = 0
model := a.summarizeProvider.Model()
usage := response.Usage
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
oldSession.Cost += cost
_, err = a.sessions.Save(summarizeCtx, oldSession)
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to save session: %w", err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
}
event = AgentEvent{
Type: AgentEventTypeSummarize,
SessionID: newSession.ID,
SessionID: oldSession.ID,
Progress: "Summary complete",
Done: true,
}

View File

@@ -16,6 +16,7 @@ type Session struct {
MessageCount int64
PromptTokens int64
CompletionTokens int64
SummaryMessageID string
Cost float64
CreatedAt int64
UpdatedAt int64
@@ -105,7 +106,11 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
Title: session.Title,
PromptTokens: session.PromptTokens,
CompletionTokens: session.CompletionTokens,
Cost: session.Cost,
SummaryMessageID: sql.NullString{
String: session.SummaryMessageID,
Valid: session.SummaryMessageID != "",
},
Cost: session.Cost,
})
if err != nil {
return Session{}, err
@@ -135,6 +140,7 @@ func (s service) fromDBItem(item db.Session) Session {
MessageCount: item.MessageCount,
PromptTokens: item.PromptTokens,
CompletionTokens: item.CompletionTokens,
SummaryMessageID: item.SummaryMessageID.String,
Cost: item.Cost,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,

View File

@@ -99,6 +99,14 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case renderFinishedMsg:
m.rendering = false
m.viewport.GotoBottom()
case pubsub.Event[session.Session]:
if msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.session.ID {
m.session = msg.Payload
if m.session.SummaryMessageID == m.currentMsgID {
delete(m.cachedContent, m.currentMsgID)
m.renderView()
}
}
case pubsub.Event[message.Message]:
needsRerender := false
if msg.Type == pubsub.CreatedEvent {
@@ -208,12 +216,15 @@ func (m *messagesCmp) renderView() {
m.uiMessages = append(m.uiMessages, cache.content...)
continue
}
isSummary := m.session.SummaryMessageID == msg.ID
assistantMessages := renderAssistantMessage(
msg,
inx,
m.messages,
m.app.Messages,
m.currentMsgID,
isSummary,
m.width,
pos,
)

View File

@@ -120,6 +120,7 @@ func renderAssistantMessage(
allMessages []message.Message, // we need this to get tool results and the user message
messagesService message.Service, // We need this to get the task tool messages
focusedUIMessageId string,
isSummary bool,
width int,
position int,
) []uiMessage {
@@ -168,6 +169,9 @@ func renderAssistantMessage(
if content == "" {
content = "*Finished without output*"
}
if isSummary {
info = append(info, baseStyle.Width(width-1).Foreground(t.TextMuted()).Render(" (summary)"))
}
content = renderMessage(content, false, true, width, info...)
messages = append(messages, uiMessage{

View File

@@ -331,30 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if payload.Done && payload.Type == agent.AgentEventTypeSummarize {
a.isCompacting = false
if payload.SessionID != "" {
// Switch to the new session
return a, func() tea.Msg {
sessions, err := a.app.Sessions.List(context.Background())
if err != nil {
return util.InfoMsg{
Type: util.InfoTypeError,
Msg: "Failed to list sessions: " + err.Error(),
}
}
for _, s := range sessions {
if s.ID == payload.SessionID {
return dialog.SessionSelectedMsg{Session: s}
}
}
return util.InfoMsg{
Type: util.InfoTypeError,
Msg: "Failed to find new session",
}
}
}
return a, util.ReportInfo("Session summarization complete")
} else if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSession.ID != "" {
model := a.app.CoderAgent.Model()