mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Read default values from a config file
This commit is contained in:
@@ -17,56 +17,46 @@ const (
|
||||
AssistantContent = "You are a helpful assistant."
|
||||
AssistantRole = "assistant"
|
||||
ErrEmptyResponse = "empty response"
|
||||
DefaultGPTModel = "gpt-3.5-turbo"
|
||||
DefaultServiceURL = "https://api.openai.com"
|
||||
CompletionPath = "/v1/chat/completions"
|
||||
ModelPath = "/v1/models"
|
||||
MaxTokenBufferPercentage = 20
|
||||
MaxTokenSize = 4096
|
||||
SystemRole = "system"
|
||||
UserRole = "user"
|
||||
gptPrefix = "gpt"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Config types.Config
|
||||
History []types.Message
|
||||
Model string
|
||||
caller http.Caller
|
||||
capacity int
|
||||
historyStore history.HistoryStore
|
||||
serviceURL string
|
||||
}
|
||||
|
||||
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
|
||||
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
|
||||
cm, err := configmanager.New(cs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &Client{
|
||||
Config: cm.Config,
|
||||
caller: caller,
|
||||
historyStore: hs,
|
||||
capacity: MaxTokenSize,
|
||||
serviceURL: DefaultServiceURL,
|
||||
}
|
||||
|
||||
// do not error out when the config cannot be read
|
||||
result.Model, _ = configmanager.New(cs).ReadModel()
|
||||
|
||||
if result.Model == "" {
|
||||
result.Model = DefaultGPTModel
|
||||
}
|
||||
|
||||
return result
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Client) WithCapacity(capacity int) *Client {
|
||||
c.capacity = capacity
|
||||
c.Config.MaxTokens = capacity
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) WithModel(model string) *Client {
|
||||
c.Model = model
|
||||
c.Config.Model = model
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) WithServiceURL(url string) *Client {
|
||||
c.serviceURL = url
|
||||
c.Config.URL = url
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -79,7 +69,7 @@ func (c *Client) WithServiceURL(url string) *Client {
|
||||
func (c *Client) ListModels() ([]string, error) {
|
||||
var result []string
|
||||
|
||||
raw, err := c.caller.Get(c.getEndpoint(ModelPath))
|
||||
raw, err := c.caller.Get(c.getEndpoint(c.Config.ModelsPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -91,7 +81,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
|
||||
for _, model := range response.Data {
|
||||
if strings.HasPrefix(model.Id, gptPrefix) {
|
||||
if model.Id != c.Model {
|
||||
if model.Id != c.Config.Model {
|
||||
result = append(result, fmt.Sprintf("- %s", model.Id))
|
||||
continue
|
||||
}
|
||||
@@ -129,7 +119,7 @@ func (c *Client) Query(input string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
raw, err := c.caller.Post(c.getEndpoint(CompletionPath), body, false)
|
||||
raw, err := c.caller.Post(c.getEndpoint(c.Config.CompletionsPath), body, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -161,7 +151,7 @@ func (c *Client) Stream(input string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := c.caller.Post(c.getEndpoint(CompletionPath), body, true)
|
||||
result, err := c.caller.Post(c.getEndpoint(c.Config.CompletionsPath), body, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -174,7 +164,7 @@ func (c *Client) Stream(input string) error {
|
||||
func (c *Client) createBody(stream bool) ([]byte, error) {
|
||||
body := types.CompletionsRequest{
|
||||
Messages: c.History,
|
||||
Model: c.Model,
|
||||
Model: c.Config.Model,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
@@ -206,7 +196,7 @@ func (c *Client) addQuery(query string) {
|
||||
}
|
||||
|
||||
func (c *Client) getEndpoint(path string) string {
|
||||
return c.serviceURL + path
|
||||
return c.Config.URL + path
|
||||
}
|
||||
|
||||
func (c *Client) prepareQuery(input string) {
|
||||
@@ -228,7 +218,7 @@ func (c *Client) processResponse(raw []byte, v interface{}) error {
|
||||
|
||||
func (c *Client) truncateHistory() {
|
||||
tokens, rolling := countTokens(c.History)
|
||||
effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
|
||||
effectiveTokenSize := calculateEffectiveTokenSize(c.Config.MaxTokens, MaxTokenBufferPercentage)
|
||||
|
||||
if tokens <= effectiveTokenSize {
|
||||
return
|
||||
|
||||
@@ -19,6 +19,14 @@ import (
|
||||
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore
|
||||
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
|
||||
|
||||
const (
|
||||
defaultMaxTokens = 4096
|
||||
defaultURL = "https://api.openai.com"
|
||||
defaultModel = "gpt-3.5-turbo"
|
||||
defaultCompletionsPath = "/v1/chat/completions"
|
||||
defaultModelsPath = "/v1/models"
|
||||
)
|
||||
|
||||
var (
|
||||
mockCtrl *gomock.Controller
|
||||
mockCaller *MockCaller
|
||||
@@ -40,6 +48,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
mockCaller = NewMockCaller(mockCtrl)
|
||||
mockHistoryStore = NewMockHistoryStore(mockCtrl)
|
||||
mockConfigStore = NewMockConfigStore(mockCtrl)
|
||||
|
||||
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
|
||||
})
|
||||
|
||||
@@ -108,12 +117,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
respBytes, err := tt.setupPostReturn()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, false).Return(respBytes, tt.postError)
|
||||
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, false).Return(respBytes, tt.postError)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -137,13 +146,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
ID: "id",
|
||||
Object: "object",
|
||||
Created: 0,
|
||||
Model: subject.Model,
|
||||
Model: subject.Config.Model,
|
||||
Choices: []types.Choice{choice},
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, false).Return(respBytes, nil)
|
||||
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(respBytes, nil)
|
||||
|
||||
var request types.CompletionsRequest
|
||||
err = json.Unmarshal(expectedBody, &request)
|
||||
@@ -215,7 +224,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
factory.withHistory(history)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
@@ -260,7 +269,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
// messages get truncated. Index 1+2 are cut out
|
||||
messages = append(messages[:1], messages[3:]...)
|
||||
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
@@ -279,11 +288,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, true).Return(nil, errors.New(errorMsg))
|
||||
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, true).Return(nil, errors.New(errorMsg))
|
||||
|
||||
err := subject.Stream(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -294,10 +303,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, true).Return([]byte(answer), nil)
|
||||
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).Return([]byte(answer), nil)
|
||||
|
||||
messages = createMessages(history, query)
|
||||
|
||||
@@ -315,7 +324,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
@@ -339,7 +348,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(history, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
@@ -351,7 +360,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, errors.New(errorMsg))
|
||||
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, errors.New(errorMsg))
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -360,7 +369,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
it("throws an error when the response is empty", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, nil)
|
||||
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, nil)
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -370,7 +379,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return([]byte(malformed), nil)
|
||||
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return([]byte(malformed), nil)
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -382,7 +391,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
response, err := utils.FileToBytes("models.json")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(response, nil)
|
||||
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(response, nil)
|
||||
|
||||
result, err := subject.ListModels()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@@ -451,6 +460,14 @@ type clientFactory struct {
|
||||
}
|
||||
|
||||
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
|
||||
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
|
||||
Model: defaultModel,
|
||||
MaxTokens: defaultMaxTokens,
|
||||
URL: defaultURL,
|
||||
CompletionsPath: defaultCompletionsPath,
|
||||
ModelsPath: defaultModelsPath,
|
||||
}, nil).Times(1)
|
||||
|
||||
return &clientFactory{
|
||||
mockCaller: mc,
|
||||
mockConfigStore: mcs,
|
||||
@@ -460,12 +477,20 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
|
||||
|
||||
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
|
||||
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
|
||||
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
|
||||
|
||||
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
return c.WithCapacity(50)
|
||||
}
|
||||
|
||||
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
|
||||
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
|
||||
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
|
||||
|
||||
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
return c.WithCapacity(50)
|
||||
}
|
||||
|
||||
func (f *clientFactory) withoutHistory() {
|
||||
|
||||
@@ -49,6 +49,21 @@ func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
|
||||
}
|
||||
|
||||
// ReadDefaults mocks base method.
|
||||
func (m *MockConfigStore) ReadDefaults() (types.Config, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ReadDefaults")
|
||||
ret0, _ := ret[0].(types.Config)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ReadDefaults indicates an expected call of ReadDefaults.
|
||||
func (mr *MockConfigStoreMockRecorder) ReadDefaults() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadDefaults", reflect.TypeOf((*MockConfigStore)(nil).ReadDefaults))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockConfigStore) Write(arg0 types.Config) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user