mirror of https://github.com/kardolus/chatgpt-cli.git (synced 2024-09-08 23:15:00 +03:00)

Implement a rolling window for the history
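The client now trims its message history to a rolling window before every request: Client gains a capacity field (a new parameter of New, with NewDefault falling back to MaxTokenSize), countTokens assigns each message an approximate token count, and truncateHistory drops the oldest messages after the system prompt whenever the total exceeds the effective window of capacity minus a MaxTokenBufferPercentage safety margin.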
@@ -7,27 +7,41 @@ import (
     "github.com/kardolus/chatgpt-cli/history"
     "github.com/kardolus/chatgpt-cli/http"
     "github.com/kardolus/chatgpt-cli/types"
+    "strings"
+    "unicode/utf8"
 )

 const (
-    AssistantContent = "You are a helpful assistant."
-    AssistantRole    = "assistant"
-    GPTModel         = "gpt-3.5-turbo"
-    SystemRole       = "system"
-    URL              = "https://api.openai.com/v1/chat/completions"
-    UserRole         = "user"
+    AssistantContent         = "You are a helpful assistant."
+    AssistantRole            = "assistant"
+    GPTModel                 = "gpt-3.5-turbo"
+    MaxTokenBufferPercentage = 20
+    MaxTokenSize             = 4096
+    SystemRole               = "system"
+    URL                      = "https://api.openai.com/v1/chat/completions"
+    UserRole                 = "user"
 )

 type Client struct {
     caller     http.Caller
     readWriter history.Store
     history    []types.Message
+    capacity   int
 }

-func New(caller http.Caller, rw history.Store) *Client {
+func New(caller http.Caller, rw history.Store, capacity int) *Client {
     return &Client{
         caller:     caller,
         readWriter: rw,
+        capacity:   capacity,
     }
 }
+
+func NewDefault(caller http.Caller, rw history.Store) *Client {
+    return &Client{
+        caller:     caller,
+        readWriter: rw,
+        capacity:   MaxTokenSize,
+    }
+}
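NewDefault preserves the old two-argument construction path, defaulting capacity to MaxTokenSize (4096 tokens); callers that know their model's context limit can pass it explicitly through New.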
@@ -115,9 +129,31 @@ func (c *Client) initHistory(query string) {
         }}
     }

-    // TODO implement sliding window

     c.history = append(c.history, message)
+    c.truncateHistory()
 }

+func (c *Client) truncateHistory() {
+    tokens, rolling := countTokens(c.history)
+    effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
+
+    if tokens <= effectiveTokenSize {
+        return
+    }
+
+    var index int
+    var total int
+    diff := tokens - effectiveTokenSize
+
+    for i := 1; i < len(rolling); i++ {
+        total += rolling[i]
+        if total > diff {
+            index = i
+            break
+        }
+    }
+
+    c.history = append(c.history[:1], c.history[index+1:]...)
+}
+
 func (c *Client) updateHistory(response string) {
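To make the window mechanics concrete, here is a small standalone sketch (not part of this commit) that applies the same drop-oldest-first logic to plain per-message token counts. The names truncate, tokens, and budget are illustrative only; the token values mirror the approximation defined in the next hunk.

package main

import "fmt"

// truncate mirrors truncateHistory: keep element 0 (the system message)
// and walk the remaining per-message token counts until enough of the
// oldest entries are marked for removal to bring the total within budget.
func truncate(tokens []int, budget int) []int {
    total := 0
    for _, t := range tokens {
        total += t
    }
    if total <= budget {
        return tokens
    }

    diff := total - budget
    index, sum := 0, 0
    for i := 1; i < len(tokens); i++ {
        sum += tokens[i]
        if sum > diff {
            index = i
            break
        }
    }
    // Drop indices 1 through index, keeping the system message at index 0.
    return append(tokens[:1], tokens[index+1:]...)
}

func main() {
    history := []int{14, 5, 4, 5, 4, 5, 4} // system prompt plus three Q/A pairs
    fmt.Println(truncate(history, 30))     // prints [14 4 5 4]
}

Note that, like the original, the window always preserves the first element, so the system prompt survives truncation.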
@@ -127,3 +163,32 @@ func (c *Client) updateHistory(response string) {
     })
     _ = c.readWriter.Write(c.history)
 }
+
+func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {
+    adjustedPercentage := 100 - bufferPercentage
+    effectiveTokenSize := (maxTokenSize * adjustedPercentage) / 100
+    return effectiveTokenSize
+}
+
+func countTokens(messages []types.Message) (int, []int) {
+    var result int
+    var rolling []int
+
+    for _, message := range messages {
+        charCount, wordCount := 0, 0
+        words := strings.Fields(message.Content)
+        wordCount += len(words)
+
+        for _, word := range words {
+            charCount += utf8.RuneCountInString(word)
+        }
+
+        // This is a simple approximation; actual token count may differ.
+        // You can adjust this based on your language and the specific tokenizer used by the model.
+        tokenCountForMessage := (charCount + wordCount) / 2
+        result += tokenCountForMessage
+        rolling = append(rolling, tokenCountForMessage)
+    }
+
+    return result, rolling
+}
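As a sanity check on the heuristic: the system prompt "You are a helpful assistant." splits into 5 words totalling 24 non-space runes, so countTokens scores it at (24 + 5) / 2 = 14 approximate tokens, and the default window is calculateEffectiveTokenSize(4096, 20) = 4096 * 80 / 100 = 3276 tokens. A standalone snippet (illustrative, not part of the commit) reproducing both numbers:

package main

import (
    "fmt"
    "strings"
    "unicode/utf8"
)

func main() {
    // Score one message with the same approximation used by countTokens.
    msg := "You are a helpful assistant."
    words := strings.Fields(msg)
    charCount := 0
    for _, w := range words {
        charCount += utf8.RuneCountInString(w)
    }
    fmt.Println((charCount + len(words)) / 2) // prints 14

    // Effective window for the default capacity with a 20% safety buffer.
    fmt.Println((4096 * (100 - 20)) / 100) // prints 3276
}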
@@ -36,7 +36,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
         mockCtrl = gomock.NewController(t)
         mockCaller = NewMockCaller(mockCtrl)
         mockStore = NewMockStore(mockCtrl)
-        subject = client.New(mockCaller, mockStore)
+        subject = client.New(mockCaller, mockStore, 50)
     })

     it.After(func() {
@@ -129,8 +129,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
         Expect(err).NotTo(HaveOccurred())
         mockCaller.EXPECT().Post(client.URL, expectedBody, false).Return(respBytes, nil)

-        messages = createMessages(history, query)
-        mockStore.EXPECT().Write(append(messages, types.Message{
+        var request types.Request
+        err = json.Unmarshal(expectedBody, &request)
+        Expect(err).NotTo(HaveOccurred())
+
+        mockStore.EXPECT().Write(append(request.Messages, types.Message{
             Role:    client.AssistantRole,
             Content: answer,
         }))
@@ -162,6 +165,48 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
             body, err = createBody(messages, false)
             Expect(err).NotTo(HaveOccurred())

             testValidHTTPResponse(history, body)
         })
+        it("truncates the history as expected", func() {
+            history := []types.Message{
+                {
+                    Role:    client.SystemRole,
+                    Content: client.AssistantContent,
+                },
+                {
+                    Role:    client.UserRole,
+                    Content: "question 1",
+                },
+                {
+                    Role:    client.AssistantRole,
+                    Content: "answer 1",
+                },
+                {
+                    Role:    client.UserRole,
+                    Content: "question 2",
+                },
+                {
+                    Role:    client.AssistantRole,
+                    Content: "answer 2",
+                },
+                {
+                    Role:    client.UserRole,
+                    Content: "question 3",
+                },
+                {
+                    Role:    client.AssistantRole,
+                    Content: "answer 3",
+                },
+            }
+
+            messages = createMessages(history, query)
+
+            // messages get truncated. Index 1+2 are cut out
+            messages = append(messages[:1], messages[3:]...)
+
+            body, err = createBody(messages, false)
+            Expect(err).NotTo(HaveOccurred())
+
+            testValidHTTPResponse(history, body)
+        })
     })
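The expected truncation in this last test follows from the same arithmetic: with the test capacity of 50 and the 20% buffer, the effective window is 50 * 80 / 100 = 40 approximate tokens. The seeded history plus the query overflows that budget, so the oldest question/answer pair after the system message (indices 1 and 2) is dropped, which is exactly what the truncated messages slice asserts.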