Implement a rolling window for the history

This commit is contained in:
kardolus
2023-05-03 13:45:49 -04:00
parent 7d74bcd9a4
commit 75bd5b9a2c
4 changed files with 130 additions and 13 deletions

View File

@@ -7,27 +7,41 @@ import (
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
"strings"
"unicode/utf8"
)
const (
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
GPTModel = "gpt-3.5-turbo"
SystemRole = "system"
URL = "https://api.openai.com/v1/chat/completions"
UserRole = "user"
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
GPTModel = "gpt-3.5-turbo"
MaxTokenBufferPercentage = 20
MaxTokenSize = 4096
SystemRole = "system"
URL = "https://api.openai.com/v1/chat/completions"
UserRole = "user"
)
// Client talks to the chat-completions API through caller while keeping
// an in-memory conversation history that is persisted via readWriter and
// truncated to stay within a token capacity.
type Client struct {
// caller performs the HTTP requests against the completions endpoint.
caller http.Caller
// readWriter persists the conversation history between invocations.
readWriter history.Store
// history is the running conversation; index 0 holds the system message.
history []types.Message
// capacity is the token budget used by truncateHistory (see
// MaxTokenBufferPercentage for the safety margin applied to it).
capacity int
}
func New(caller http.Caller, rw history.Store) *Client {
// New constructs a Client with an explicit token capacity that bounds
// how much conversation history is kept.
func New(caller http.Caller, rw history.Store, capacity int) *Client {
	c := &Client{}
	c.caller = caller
	c.readWriter = rw
	c.capacity = capacity
	return c
}
// NewDefault constructs a Client whose token capacity is the model's
// default MaxTokenSize.
func NewDefault(caller http.Caller, rw history.Store) *Client {
	return &Client{caller: caller, readWriter: rw, capacity: MaxTokenSize}
}
@@ -115,9 +129,31 @@ func (c *Client) initHistory(query string) {
}}
}
// TODO implement sliding window
c.history = append(c.history, message)
c.truncateHistory()
}
// truncateHistory drops the oldest non-system messages until the
// estimated token count of c.history fits within the effective budget
// (capacity reduced by MaxTokenBufferPercentage). The system message at
// index 0 is always preserved.
func (c *Client) truncateHistory() {
	tokens, rolling := countTokens(c.history)

	effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)

	if tokens <= effectiveTokenSize {
		return
	}

	diff := tokens - effectiveTokenSize

	// Walk forward from the oldest non-system message, accumulating the
	// per-message estimates until enough has been marked for removal.
	// Default index to the last message: if even removing every
	// non-system message cannot cover diff (e.g. an oversized system
	// prompt), we still truncate down to just the system message instead
	// of silently keeping everything (the original left index at 0 in
	// that case, removing nothing).
	index := len(rolling) - 1
	total := 0
	for i := 1; i < len(rolling); i++ {
		total += rolling[i]
		if total > diff {
			index = i
			break
		}
	}

	// Keep history[0] (system message) plus everything after index.
	// The overlapping append copies left-to-right, which is safe here.
	c.history = append(c.history[:1], c.history[index+1:]...)
}
func (c *Client) updateHistory(response string) {
@@ -127,3 +163,32 @@ func (c *Client) updateHistory(response string) {
})
_ = c.readWriter.Write(c.history)
}
// calculateEffectiveTokenSize returns maxTokenSize reduced by
// bufferPercentage percent, using truncating integer arithmetic.
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {
	return maxTokenSize * (100 - bufferPercentage) / 100
}
// countTokens estimates the token footprint of each message and of the
// whole conversation. It returns the grand total along with a
// per-message slice in the same order as messages (nil when messages is
// empty).
//
// The estimate for a message is (runes-in-words + word-count) / 2 — a
// rough heuristic; the actual tokenizer used by the model may differ.
func countTokens(messages []types.Message) (int, []int) {
	var perMessage []int
	total := 0

	for _, msg := range messages {
		runes, words := 0, 0
		for _, word := range strings.Fields(msg.Content) {
			words++
			runes += utf8.RuneCountInString(word)
		}

		estimate := (runes + words) / 2
		perMessage = append(perMessage, estimate)
		total += estimate
	}

	return total, perMessage
}

View File

@@ -36,7 +36,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockCtrl = gomock.NewController(t)
mockCaller = NewMockCaller(mockCtrl)
mockStore = NewMockStore(mockCtrl)
subject = client.New(mockCaller, mockStore)
subject = client.New(mockCaller, mockStore, 50)
})
it.After(func() {
@@ -129,8 +129,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, expectedBody, false).Return(respBytes, nil)
messages = createMessages(history, query)
mockStore.EXPECT().Write(append(messages, types.Message{
var request types.Request
err = json.Unmarshal(expectedBody, &request)
Expect(err).NotTo(HaveOccurred())
mockStore.EXPECT().Write(append(request.Messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
@@ -162,6 +165,48 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
})
it("truncates the history as expected", func() {
history := []types.Message{
{
Role: client.SystemRole,
Content: client.AssistantContent,
},
{
Role: client.UserRole,
Content: "question 1",
},
{
Role: client.AssistantRole,
Content: "answer 1",
},
{
Role: client.UserRole,
Content: "question 2",
},
{
Role: client.AssistantRole,
Content: "answer 2",
},
{
Role: client.UserRole,
Content: "question 3",
},
{
Role: client.AssistantRole,
Content: "answer 3",
},
}
messages = createMessages(history, query)
// messages get truncated. Index 1+2 are cut out
messages = append(messages[:1], messages[3:]...)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
})
})