Read default values from a config file

This commit is contained in:
kardolus
2023-06-17 09:53:26 -04:00
parent a5c9df3598
commit 9a512533a0
13 changed files with 333 additions and 88 deletions

View File

@@ -17,56 +17,46 @@ const (
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
ErrEmptyResponse = "empty response"
DefaultGPTModel = "gpt-3.5-turbo"
DefaultServiceURL = "https://api.openai.com"
CompletionPath = "/v1/chat/completions"
ModelPath = "/v1/models"
MaxTokenBufferPercentage = 20
MaxTokenSize = 4096
SystemRole = "system"
UserRole = "user"
gptPrefix = "gpt"
)
type Client struct {
Config types.Config
History []types.Message
Model string
caller http.Caller
capacity int
historyStore history.HistoryStore
serviceURL string
}
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
cm, err := configmanager.New(cs)
if err != nil {
return nil, err
}
result := &Client{
Config: cm.Config,
caller: caller,
historyStore: hs,
capacity: MaxTokenSize,
serviceURL: DefaultServiceURL,
}
// do not error out when the config cannot be read
result.Model, _ = configmanager.New(cs).ReadModel()
if result.Model == "" {
result.Model = DefaultGPTModel
}
return result
return result, nil
}
func (c *Client) WithCapacity(capacity int) *Client {
c.capacity = capacity
c.Config.MaxTokens = capacity
return c
}
func (c *Client) WithModel(model string) *Client {
c.Model = model
c.Config.Model = model
return c
}
func (c *Client) WithServiceURL(url string) *Client {
c.serviceURL = url
c.Config.URL = url
return c
}
@@ -79,7 +69,7 @@ func (c *Client) WithServiceURL(url string) *Client {
func (c *Client) ListModels() ([]string, error) {
var result []string
raw, err := c.caller.Get(c.getEndpoint(ModelPath))
raw, err := c.caller.Get(c.getEndpoint(c.Config.ModelsPath))
if err != nil {
return nil, err
}
@@ -91,7 +81,7 @@ func (c *Client) ListModels() ([]string, error) {
for _, model := range response.Data {
if strings.HasPrefix(model.Id, gptPrefix) {
if model.Id != c.Model {
if model.Id != c.Config.Model {
result = append(result, fmt.Sprintf("- %s", model.Id))
continue
}
@@ -129,7 +119,7 @@ func (c *Client) Query(input string) (string, error) {
return "", err
}
raw, err := c.caller.Post(c.getEndpoint(CompletionPath), body, false)
raw, err := c.caller.Post(c.getEndpoint(c.Config.CompletionsPath), body, false)
if err != nil {
return "", err
}
@@ -161,7 +151,7 @@ func (c *Client) Stream(input string) error {
return err
}
result, err := c.caller.Post(c.getEndpoint(CompletionPath), body, true)
result, err := c.caller.Post(c.getEndpoint(c.Config.CompletionsPath), body, true)
if err != nil {
return err
}
@@ -174,7 +164,7 @@ func (c *Client) Stream(input string) error {
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.CompletionsRequest{
Messages: c.History,
Model: c.Model,
Model: c.Config.Model,
Stream: stream,
}
@@ -206,7 +196,7 @@ func (c *Client) addQuery(query string) {
}
func (c *Client) getEndpoint(path string) string {
return c.serviceURL + path
return c.Config.URL + path
}
func (c *Client) prepareQuery(input string) {
@@ -228,7 +218,7 @@ func (c *Client) processResponse(raw []byte, v interface{}) error {
func (c *Client) truncateHistory() {
tokens, rolling := countTokens(c.History)
effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
effectiveTokenSize := calculateEffectiveTokenSize(c.Config.MaxTokens, MaxTokenBufferPercentage)
if tokens <= effectiveTokenSize {
return

View File

@@ -19,6 +19,14 @@ import (
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
const (
defaultMaxTokens = 4096
defaultURL = "https://api.openai.com"
defaultModel = "gpt-3.5-turbo"
defaultCompletionsPath = "/v1/chat/completions"
defaultModelsPath = "/v1/models"
)
var (
mockCtrl *gomock.Controller
mockCaller *MockCaller
@@ -40,6 +48,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockCaller = NewMockCaller(mockCtrl)
mockHistoryStore = NewMockHistoryStore(mockCtrl)
mockConfigStore = NewMockConfigStore(mockCtrl)
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
})
@@ -108,12 +117,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, false)
body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred())
respBytes, err := tt.setupPostReturn()
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, false).Return(respBytes, tt.postError)
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, false).Return(respBytes, tt.postError)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -137,13 +146,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
ID: "id",
Object: "object",
Created: 0,
Model: subject.Model,
Model: subject.Config.Model,
Choices: []types.Choice{choice},
}
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, false).Return(respBytes, nil)
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(respBytes, nil)
var request types.CompletionsRequest
err = json.Unmarshal(expectedBody, &request)
@@ -215,7 +224,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
body, err = createBody(messages, subject.Model, false)
body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body)
@@ -260,7 +269,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
// messages get truncated. Index 1+2 are cut out
messages = append(messages[:1], messages[3:]...)
body, err = createBody(messages, subject.Model, false)
body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body)
@@ -279,11 +288,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
body, err = createBody(messages, subject.Config.Model, true)
Expect(err).NotTo(HaveOccurred())
errorMsg := "error message"
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, true).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, true).Return(nil, errors.New(errorMsg))
err := subject.Stream(query)
Expect(err).To(HaveOccurred())
@@ -294,10 +303,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
body, err = createBody(messages, subject.Config.Model, true)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, true).Return([]byte(answer), nil)
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
@@ -315,7 +324,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
body, err = createBody(messages, subject.Config.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
@@ -339,7 +348,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(history, query)
body, err = createBody(messages, subject.Model, true)
body, err = createBody(messages, subject.Config.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body)
@@ -351,7 +360,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
errorMsg := "error message"
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, errors.New(errorMsg))
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -360,7 +369,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
it("throws an error when the response is empty", func() {
subject := factory.buildClientWithoutConfig()
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, nil)
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -370,7 +379,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return([]byte(malformed), nil)
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return([]byte(malformed), nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -382,7 +391,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
response, err := utils.FileToBytes("models.json")
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(response, nil)
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(response, nil)
result, err := subject.ListModels()
Expect(err).NotTo(HaveOccurred())
@@ -451,6 +460,14 @@ type clientFactory struct {
}
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
Model: defaultModel,
MaxTokens: defaultMaxTokens,
URL: defaultURL,
CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath,
}, nil).Times(1)
return &clientFactory{
mockCaller: mc,
mockConfigStore: mcs,
@@ -460,12 +477,20 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50)
}
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50)
}
func (f *clientFactory) withoutHistory() {

View File

@@ -49,6 +49,21 @@ func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
}
// ReadDefaults mocks base method.
func (m *MockConfigStore) ReadDefaults() (types.Config, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadDefaults")
ret0, _ := ret[0].(types.Config)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadDefaults indicates an expected call of ReadDefaults.
func (mr *MockConfigStoreMockRecorder) ReadDefaults() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadDefaults", reflect.TypeOf((*MockConfigStore)(nil).ReadDefaults))
}
// Write mocks base method.
func (m *MockConfigStore) Write(arg0 types.Config) error {
m.ctrl.T.Helper()