Add history handling

This commit is contained in:
kardolus
2023-05-03 01:03:09 -04:00
parent fd892219da
commit 5be0af6cb0
9 changed files with 330 additions and 42 deletions

View File

@@ -4,22 +4,31 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
)
const (
model = "gpt-3.5-turbo"
role = "user"
URL = "https://api.openai.com/v1/chat/completions"
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
GPTModel = "gpt-3.5-turbo"
SystemRole = "system"
URL = "https://api.openai.com/v1/chat/completions"
UserRole = "user"
)
type Client struct {
caller http.Caller
caller http.Caller
readWriter history.Store
history []types.Message
}
func New(caller http.Caller) *Client {
return &Client{caller: caller}
func New(caller http.Caller, rw history.Store) *Client {
return &Client{
caller: caller,
readWriter: rw,
}
}
// Query sends a query to the API and returns the response as a string.
@@ -29,7 +38,9 @@ func New(caller http.Caller) *Client {
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
body, err := CreateBody(input, false)
c.initHistory(input)
body, err := c.createBody(false)
if err != nil {
return "", err
}
@@ -52,6 +63,8 @@ func (c *Client) Query(input string) (string, error) {
return "", errors.New("no responses returned")
}
c.updateHistory(response.Choices[0].Message.Content)
return response.Choices[0].Message.Content, nil
}
@@ -61,35 +74,59 @@ func (c *Client) Query(input string) (string, error) {
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
body, err := CreateBody(input, true)
c.initHistory(input)
body, err := c.createBody(true)
if err != nil {
return err
}
_, err = c.caller.Post(URL, body, true)
result, err := c.caller.Post(URL, body, true)
if err != nil {
return err
}
c.updateHistory(string(result))
return nil
}
func CreateBody(query string, stream bool) ([]byte, error) {
message := types.Message{
Role: role,
Content: query,
}
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.Request{
Model: model,
Messages: []types.Message{message},
Model: GPTModel,
Messages: c.history,
Stream: stream,
}
result, err := json.Marshal(body)
if err != nil {
return nil, err
return json.Marshal(body)
}
func (c *Client) initHistory(query string) {
message := types.Message{
Role: UserRole,
Content: query,
}
return result, nil
c.history, _ = c.readWriter.Read()
if len(c.history) == 0 {
c.history = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
// TODO Write history specific tests
// TODO Write delete-specific tests (on store)
// TODO Test the string returned from Stream
// TODO implement sliding window
c.history = append(c.history, message)
}
func (c *Client) updateHistory(response string) {
c.history = append(c.history, types.Message{
Role: AssistantRole,
Content: response,
})
_ = c.readWriter.Write(c.history)
}

View File

@@ -14,11 +14,13 @@ import (
"github.com/sclevine/spec/report"
)
//go:generate mockgen -destination=mocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
var (
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockStore *MockStore
subject *client.Client
)
@@ -31,8 +33,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockCaller = NewMockCaller(mockCtrl)
mockStore = NewMockStore(mockCtrl)
subject = client.New(mockCaller)
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
subject = client.New(mockCaller, mockStore)
})
it.After(func() {
@@ -43,12 +47,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
const query = "test query"
var (
err error
body []byte
body []byte
messages []types.Message
err error
)
it.Before(func() {
body, err = client.CreateBody(query, false)
messages = createMessages(nil, query)
body, err = createBody(messages)
Expect(err).NotTo(HaveOccurred())
})
@@ -56,14 +62,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
@@ -71,7 +77,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
@@ -93,10 +99,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal("no responses returned"))
})
it("parses a valid http response as expected", func() {
const answer = "content"
choice := types.Choice{
Message: types.Message{
Role: "role",
Content: "content",
Content: answer,
},
FinishReason: "",
Index: 0,
@@ -113,9 +121,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
mockStore.EXPECT().Write(append(messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
result, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("content"))
Expect(result).To(Equal(answer))
})
})
}
func createBody(messages []types.Message) ([]byte, error) {
req := types.Request{
Model: client.GPTModel,
Messages: messages,
Stream: false,
}
return json.Marshal(req)
}
func createMessages(history []types.Message, query string) []types.Message {
var messages []types.Message
if len(history) == 0 {
messages = append(messages, types.Message{
Role: client.SystemRole,
Content: client.AssistantContent,
})
}
messages = append(messages, types.Message{
Role: client.UserRole,
Content: query,
})
return messages
}

78
client/iomocks_test.go Normal file
View File

@@ -0,0 +1,78 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
// Package client_test is a generated GoMock package.
package client_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
types "github.com/kardolus/chatgpt-cli/types"
)
// MockStore is a mock of Store interface.
type MockStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockStore) Delete() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete")
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
}
// Read mocks base method.
func (m *MockStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].([]types.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
}
// Write mocks base method.
func (m *MockStore) Write(arg0 []types.Message) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
}