Add custom context parsing

This commit is contained in:
kardolus
2023-05-03 18:38:12 -04:00
parent dd791fbeb9
commit c5426d68af
4 changed files with 96 additions and 25 deletions

View File

@@ -23,9 +23,9 @@ const (
)
type Client struct {
History []types.Message
caller http.Caller
readWriter history.Store
history []types.Message
capacity int
}
@@ -52,7 +52,8 @@ func NewDefault(caller http.Caller, rw history.Store) *Client {
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
c.initHistory(input)
c.initHistory()
c.addQuery(input)
body, err := c.createBody(false)
if err != nil {
@@ -88,7 +89,8 @@ func (c *Client) Query(input string) (string, error) {
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
c.initHistory(input)
c.initHistory()
c.addQuery(input)
body, err := c.createBody(true)
if err != nil {
@@ -105,36 +107,55 @@ func (c *Client) Stream(input string) error {
return nil
}
// ProvideContext adds custom context to the client's history by converting the
// provided string into a series of messages. This allows the ChatGPT API to have
// prior knowledge of the provided context when generating responses.
//
// The context string should contain the text you want to provide as context,
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
c.initHistory()
messages := createMessagesFromString(context)
c.History = append(c.History, messages...)
}
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.Request{
Model: GPTModel,
Messages: c.history,
Messages: c.History,
Stream: stream,
}
return json.Marshal(body)
}
func (c *Client) initHistory(query string) {
func (c *Client) initHistory() {
if len(c.History) != 0 {
return
}
c.History, _ = c.readWriter.Read()
if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
}
func (c *Client) addQuery(query string) {
message := types.Message{
Role: UserRole,
Content: query,
}
c.history, _ = c.readWriter.Read()
if len(c.history) == 0 {
c.history = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
c.history = append(c.history, message)
c.History = append(c.History, message)
c.truncateHistory()
}
func (c *Client) truncateHistory() {
tokens, rolling := countTokens(c.history)
tokens, rolling := countTokens(c.History)
effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
if tokens <= effectiveTokenSize {
@@ -153,15 +174,15 @@ func (c *Client) truncateHistory() {
}
}
c.history = append(c.history[:1], c.history[index+1:]...)
c.History = append(c.History[:1], c.History[index+1:]...)
}
func (c *Client) updateHistory(response string) {
c.history = append(c.history, types.Message{
c.History = append(c.History, types.Message{
Role: AssistantRole,
Content: response,
})
_ = c.readWriter.Write(c.history)
_ = c.readWriter.Write(c.History)
}
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {
@@ -192,3 +213,25 @@ func countTokens(messages []types.Message) (int, []int) {
return result, rolling
}
func createMessagesFromString(input string) []types.Message {
words := strings.Fields(input)
var messages []types.Message
for i := 0; i < len(words); i += 100 {
end := i + 100
if end > len(words) {
end = len(words)
}
content := strings.Join(words[i:end], " ")
message := types.Message{
Role: UserRole,
Content: content,
}
messages = append(messages, message)
}
return messages
}

View File

@@ -278,6 +278,23 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
})
})
when("ProvideContext()", func() {
it("updates the history with the provided context", func() {
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
subject.ProvideContext(context)
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
systemMessage := subject.History[0]
Expect(systemMessage.Role).To(Equal(client.SystemRole))
Expect(systemMessage.Content).To(Equal("You are a helpful assistant."))
contextMessage := subject.History[1]
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expect(contextMessage.Content).To(Equal(context))
})
})
}
func createBody(messages []types.Message, stream bool) ([]byte, error) {