mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add history handling
This commit is contained in:
@@ -4,22 +4,31 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/http"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
const (
|
||||
model = "gpt-3.5-turbo"
|
||||
role = "user"
|
||||
URL = "https://api.openai.com/v1/chat/completions"
|
||||
AssistantContent = "You are a helpful assistant."
|
||||
AssistantRole = "assistant"
|
||||
GPTModel = "gpt-3.5-turbo"
|
||||
SystemRole = "system"
|
||||
URL = "https://api.openai.com/v1/chat/completions"
|
||||
UserRole = "user"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
caller http.Caller
|
||||
caller http.Caller
|
||||
readWriter history.Store
|
||||
history []types.Message
|
||||
}
|
||||
|
||||
func New(caller http.Caller) *Client {
|
||||
return &Client{caller: caller}
|
||||
func New(caller http.Caller, rw history.Store) *Client {
|
||||
return &Client{
|
||||
caller: caller,
|
||||
readWriter: rw,
|
||||
}
|
||||
}
|
||||
|
||||
// Query sends a query to the API and returns the response as a string.
|
||||
@@ -29,7 +38,9 @@ func New(caller http.Caller) *Client {
|
||||
// call using the Post method. If the response is not empty, it decodes the
|
||||
// response JSON and returns the content of the first choice.
|
||||
func (c *Client) Query(input string) (string, error) {
|
||||
body, err := CreateBody(input, false)
|
||||
c.initHistory(input)
|
||||
|
||||
body, err := c.createBody(false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -52,6 +63,8 @@ func (c *Client) Query(input string) (string, error) {
|
||||
return "", errors.New("no responses returned")
|
||||
}
|
||||
|
||||
c.updateHistory(response.Choices[0].Message.Content)
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -61,35 +74,59 @@ func (c *Client) Query(input string) (string, error) {
|
||||
// input and then makes an API call using the Post method. The actual
|
||||
// processing of the streamed response is done in the Post method.
|
||||
func (c *Client) Stream(input string) error {
|
||||
body, err := CreateBody(input, true)
|
||||
c.initHistory(input)
|
||||
|
||||
body, err := c.createBody(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.caller.Post(URL, body, true)
|
||||
result, err := c.caller.Post(URL, body, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.updateHistory(string(result))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateBody(query string, stream bool) ([]byte, error) {
|
||||
message := types.Message{
|
||||
Role: role,
|
||||
Content: query,
|
||||
}
|
||||
|
||||
func (c *Client) createBody(stream bool) ([]byte, error) {
|
||||
body := types.Request{
|
||||
Model: model,
|
||||
Messages: []types.Message{message},
|
||||
Model: GPTModel,
|
||||
Messages: c.history,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return json.Marshal(body)
|
||||
}
|
||||
|
||||
func (c *Client) initHistory(query string) {
|
||||
message := types.Message{
|
||||
Role: UserRole,
|
||||
Content: query,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
c.history, _ = c.readWriter.Read()
|
||||
if len(c.history) == 0 {
|
||||
c.history = []types.Message{{
|
||||
Role: SystemRole,
|
||||
Content: AssistantContent,
|
||||
}}
|
||||
}
|
||||
|
||||
// TODO Write history specific tests
|
||||
// TODO Write delete-specific tests (on store)
|
||||
// TODO Test the string returned from Stream
|
||||
// TODO implement sliding window
|
||||
|
||||
c.history = append(c.history, message)
|
||||
}
|
||||
|
||||
func (c *Client) updateHistory(response string) {
|
||||
c.history = append(c.history, types.Message{
|
||||
Role: AssistantRole,
|
||||
Content: response,
|
||||
})
|
||||
_ = c.readWriter.Write(c.history)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,13 @@ import (
|
||||
"github.com/sclevine/spec/report"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
|
||||
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
|
||||
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
|
||||
|
||||
var (
|
||||
mockCtrl *gomock.Controller
|
||||
mockCaller *MockCaller
|
||||
mockStore *MockStore
|
||||
subject *client.Client
|
||||
)
|
||||
|
||||
@@ -31,8 +33,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
RegisterTestingT(t)
|
||||
mockCtrl = gomock.NewController(t)
|
||||
mockCaller = NewMockCaller(mockCtrl)
|
||||
mockStore = NewMockStore(mockCtrl)
|
||||
|
||||
subject = client.New(mockCaller)
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
subject = client.New(mockCaller, mockStore)
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
@@ -43,12 +47,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
const query = "test query"
|
||||
|
||||
var (
|
||||
err error
|
||||
body []byte
|
||||
body []byte
|
||||
messages []types.Message
|
||||
err error
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
body, err = client.CreateBody(query, false)
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -56,14 +62,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(errorMsg))
|
||||
})
|
||||
it("throws an error when the response is empty", func() {
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("empty response"))
|
||||
})
|
||||
@@ -71,7 +77,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
|
||||
})
|
||||
@@ -93,10 +99,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err.Error()).To(Equal("no responses returned"))
|
||||
})
|
||||
it("parses a valid http response as expected", func() {
|
||||
const answer = "content"
|
||||
|
||||
choice := types.Choice{
|
||||
Message: types.Message{
|
||||
Role: "role",
|
||||
Content: "content",
|
||||
Content: answer,
|
||||
},
|
||||
FinishReason: "",
|
||||
Index: 0,
|
||||
@@ -113,9 +121,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
|
||||
|
||||
mockStore.EXPECT().Write(append(messages, types.Message{
|
||||
Role: client.AssistantRole,
|
||||
Content: answer,
|
||||
}))
|
||||
|
||||
result, err := subject.Query(query)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal("content"))
|
||||
Expect(result).To(Equal(answer))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func createBody(messages []types.Message) ([]byte, error) {
|
||||
req := types.Request{
|
||||
Model: client.GPTModel,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
func createMessages(history []types.Message, query string) []types.Message {
|
||||
var messages []types.Message
|
||||
|
||||
if len(history) == 0 {
|
||||
messages = append(messages, types.Message{
|
||||
Role: client.SystemRole,
|
||||
Content: client.AssistantContent,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, types.Message{
|
||||
Role: client.UserRole,
|
||||
Content: query,
|
||||
})
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
78
client/iomocks_test.go
Normal file
78
client/iomocks_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
|
||||
|
||||
// Package client_test is a generated GoMock package.
|
||||
package client_test
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
types "github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
// MockStore is a mock of Store interface.
|
||||
type MockStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStoreMockRecorder
|
||||
}
|
||||
|
||||
// MockStoreMockRecorder is the mock recorder for MockStore.
|
||||
type MockStoreMockRecorder struct {
|
||||
mock *MockStore
|
||||
}
|
||||
|
||||
// NewMockStore creates a new mock instance.
|
||||
func NewMockStore(ctrl *gomock.Controller) *MockStore {
|
||||
mock := &MockStore{ctrl: ctrl}
|
||||
mock.recorder = &MockStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockStore) Delete() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockStore) Read() ([]types.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read")
|
||||
ret0, _ := ret[0].([]types.Message)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockStore) Write(arg0 []types.Message) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write.
|
||||
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
|
||||
}
|
||||
Reference in New Issue
Block a user