Add the --set-model flag

This commit is contained in:
kardolus
2023-06-08 11:14:51 -04:00
parent b0708c5bfa
commit 1ea1b684fb
13 changed files with 621 additions and 134 deletions

View File

@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
@@ -26,20 +28,28 @@ const (
)
type Client struct {
History []types.Message
caller http.Caller
capacity int
model string
readWriter history.Store
History []types.Message
Model string
caller http.Caller
capacity int
historyStore history.HistoryStore
}
func New(caller http.Caller, rw history.Store) *Client {
return &Client{
caller: caller,
readWriter: rw,
capacity: MaxTokenSize,
model: DefaultGPTModel,
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
result := &Client{
caller: caller,
historyStore: hs,
capacity: MaxTokenSize,
}
// do not error out when the config cannot be read
result.Model, _ = configmanager.New(cs).ReadModel()
if result.Model == "" {
result.Model = DefaultGPTModel
}
return result
}
func (c *Client) WithCapacity(capacity int) *Client {
@@ -48,7 +58,7 @@ func (c *Client) WithCapacity(capacity int) *Client {
}
func (c *Client) WithModel(model string) *Client {
c.model = model
c.Model = model
return c
}
@@ -73,7 +83,7 @@ func (c *Client) ListModels() ([]string, error) {
for _, model := range response.Data {
if strings.HasPrefix(model.Id, gptPrefix) {
if model.Id != DefaultGPTModel {
if model.Id != c.Model {
result = append(result, fmt.Sprintf("- %s", model.Id))
continue
}
@@ -156,7 +166,7 @@ func (c *Client) Stream(input string) error {
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.CompletionsRequest{
Messages: c.History,
Model: c.model,
Model: c.Model,
Stream: stream,
}
@@ -168,7 +178,7 @@ func (c *Client) initHistory() {
return
}
c.History, _ = c.readWriter.Read()
c.History, _ = c.historyStore.Read()
if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
@@ -232,7 +242,7 @@ func (c *Client) updateHistory(response string) {
Role: AssistantRole,
Content: response,
})
_ = c.readWriter.Write(c.History)
_ = c.historyStore.Write(c.History)
}
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {

View File

@@ -16,13 +16,15 @@ import (
)
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
var (
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockStore *MockStore
subject *client.Client
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockHistoryStore *MockHistoryStore
mockConfigStore *MockConfigStore
factory *clientFactory
)
func TestUnitClient(t *testing.T) {
@@ -36,8 +38,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockCaller = NewMockCaller(mockCtrl)
mockStore = NewMockStore(mockCtrl)
subject = client.New(mockCaller, mockStore).WithCapacity(50)
mockHistoryStore = NewMockHistoryStore(mockCtrl)
mockConfigStore = NewMockConfigStore(mockCtrl)
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
})
it.After(func() {
@@ -51,63 +54,75 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err error
)
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
})
type TestCase struct {
description string
setupPostReturn func() ([]byte, error)
postError error
expectedError string
}
it("throws an error when the http callout fails", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
tests := []TestCase{
{
description: "throws an error when the http callout fails",
setupPostReturn: func() ([]byte, error) { return nil, nil },
postError: errors.New("error message"),
expectedError: "error message",
},
{
description: "throws an error when the response is empty",
setupPostReturn: func() ([]byte, error) { return nil, nil },
postError: nil,
expectedError: "empty response",
},
{
description: "throws an error when the response is a malformed json",
setupPostReturn: func() ([]byte, error) {
malformed := `{"invalid":"json"` // missing closing brace
return []byte(malformed), nil
},
postError: nil,
expectedError: "failed to decode response:",
},
{
description: "throws an error when the response is missing Choices",
setupPostReturn: func() ([]byte, error) {
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
Model: "model",
Choices: []types.Choice{},
}
errorMsg := "error message"
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg))
respBytes, err := json.Marshal(response)
return respBytes, err
},
postError: nil,
expectedError: "no responses returned",
},
}
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil)
for _, tt := range tests {
it(tt.description, func() {
factory.withoutHistory()
subject := factory.buildClientWithoutConfig()
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil)
respBytes, err := tt.setupPostReturn()
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
it("throws an error when the response is missing Choices", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
})
}
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
Model: "model",
Choices: []types.Choice{},
}
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("no responses returned"))
})
when("a valid http response is received", func() {
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
const answer = "content"
choice := types.Choice{
@@ -122,7 +137,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
ID: "id",
Object: "object",
Created: 0,
Model: client.DefaultGPTModel,
Model: subject.Model,
Choices: []types.Choice{choice},
}
@@ -134,7 +149,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err = json.Unmarshal(expectedBody, &request)
Expect(err).NotTo(HaveOccurred())
mockStore.EXPECT().Write(append(request.Messages, types.Message{
mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
@@ -144,8 +159,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(result).To(Equal(answer))
}
it("returns the expected result for an empty history", func() {
testValidHTTPResponse(nil, body)
it("uses the model specified by the WithModel method instead of the default model", func() {
const model = "overwritten"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithoutConfig().WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("uses the model specified by the configuration instead of the default model", func() {
const model = "overwritten"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: model,
})
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("when WithModel is used and a configuration is present, WithModel takes precedence", func() {
const model = "with-model"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: "config-model",
}).WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("returns the expected result for a non-empty history", func() {
history := []types.Message{
@@ -163,10 +212,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
},
}
messages = createMessages(history, query)
body, err = createBody(messages, false)
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
it("truncates the history as expected", func() {
history := []types.Message{
@@ -202,13 +254,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
messages = createMessages(history, query)
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
// messages get truncated. Index 1+2 are cut out
messages = append(messages[:1], messages[3:]...)
body, err = createBody(messages, false)
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
})
})
@@ -219,14 +274,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err error
)
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
})
it("throws an error when the http callout fails", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
factory.withoutHistory()
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
errorMsg := "error message"
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
@@ -238,13 +292,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
when("a valid http response is received", func() {
const answer = "answer"
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
mockStore.EXPECT().Write(append(messages, types.Message{
mockHistoryStore.EXPECT().Write(append(messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
@@ -254,7 +311,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}
it("returns the expected result for an empty history", func() {
testValidHTTPResponse(nil, body)
factory.withHistory(nil)
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("returns the expected result for a non-empty history", func() {
history := []types.Message{
@@ -271,16 +335,21 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Content: "answer 1",
},
}
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
messages = createMessages(history, query)
body, err = createBody(messages, true)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
})
})
when("ListModels()", func() {
it("throws an error when the http callout fails", func() {
subject := factory.buildClientWithoutConfig()
errorMsg := "error message"
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
@@ -289,6 +358,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
subject := factory.buildClientWithoutConfig()
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
_, err := subject.ListModels()
@@ -296,6 +367,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
subject := factory.buildClientWithoutConfig()
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
@@ -304,6 +377,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
it("filters gpt models as expected", func() {
subject := factory.buildClientWithoutConfig()
response, err := utils.FileToBytes("models.json")
Expect(err).NotTo(HaveOccurred())
@@ -319,8 +394,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
when("ProvideContext()", func() {
it("updates the history with the provided context", func() {
subject := factory.buildClientWithoutConfig()
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
subject.ProvideContext(context)
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
@@ -336,9 +414,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
}
func createBody(messages []types.Message, stream bool) ([]byte, error) {
func createBody(messages []types.Message, model string, stream bool) ([]byte, error) {
req := types.CompletionsRequest{
Model: client.DefaultGPTModel,
Model: model,
Messages: messages,
Stream: stream,
}
@@ -365,3 +443,35 @@ func createMessages(history []types.Message, query string) []types.Message {
return messages
}
type clientFactory struct {
mockCaller *MockCaller
mockConfigStore *MockConfigStore
mockHistoryStore *MockHistoryStore
}
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
return &clientFactory{
mockCaller: mc,
mockConfigStore: mcs,
mockHistoryStore: mhs,
}
}
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
}
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
}
func (f *clientFactory) withoutHistory() {
f.mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
}
func (f *clientFactory) withHistory(history []types.Message) {
f.mockHistoryStore.EXPECT().Read().Return(history, nil).Times(1)
}

View File

@@ -0,0 +1,64 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore)
// Package client_test is a generated GoMock package.
package client_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
types "github.com/kardolus/chatgpt-cli/types"
)
// MockConfigStore is a mock of ConfigStore interface.
type MockConfigStore struct {
ctrl *gomock.Controller
recorder *MockConfigStoreMockRecorder
}
// MockConfigStoreMockRecorder is the mock recorder for MockConfigStore.
type MockConfigStoreMockRecorder struct {
mock *MockConfigStore
}
// NewMockConfigStore creates a new mock instance.
func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore {
mock := &MockConfigStore{ctrl: ctrl}
mock.recorder = &MockConfigStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// Read mocks base method.
func (m *MockConfigStore) Read() (types.Config, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].(types.Config)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
}
// Write mocks base method.
func (m *MockConfigStore) Write(arg0 types.Config) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0)
}

View File

@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: HistoryStore)
// Package client_test is a generated GoMock package.
package client_test
@@ -11,31 +11,31 @@ import (
types "github.com/kardolus/chatgpt-cli/types"
)
// MockStore is a mock of Store interface.
type MockStore struct {
// MockHistoryStore is a mock of HistoryStore interface.
type MockHistoryStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
recorder *MockHistoryStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
// MockHistoryStoreMockRecorder is the mock recorder for MockHistoryStore.
type MockHistoryStoreMockRecorder struct {
mock *MockHistoryStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
// NewMockHistoryStore creates a new mock instance.
func NewMockHistoryStore(ctrl *gomock.Controller) *MockHistoryStore {
mock := &MockHistoryStore{ctrl: ctrl}
mock.recorder = &MockHistoryStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockStore) Delete() error {
func (m *MockHistoryStore) Delete() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete")
ret0, _ := ret[0].(error)
@@ -43,13 +43,13 @@ func (m *MockStore) Delete() error {
}
// Delete indicates an expected call of Delete.
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Delete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockHistoryStore)(nil).Delete))
}
// Read mocks base method.
func (m *MockStore) Read() ([]types.Message, error) {
func (m *MockHistoryStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].([]types.Message)
@@ -58,13 +58,13 @@ func (m *MockStore) Read() ([]types.Message, error) {
}
// Read indicates an expected call of Read.
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockHistoryStore)(nil).Read))
}
// Write mocks base method.
func (m *MockStore) Write(arg0 []types.Message) error {
func (m *MockHistoryStore) Write(arg0 []types.Message) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
@@ -72,7 +72,7 @@ func (m *MockStore) Write(arg0 []types.Message) error {
}
// Write indicates an expected call of Write.
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockHistoryStore)(nil).Write), arg0)
}