Add additional configuration options

- This includes temperature, top_p, frequency_penalty and
  presence_penalty
This commit is contained in:
kardolus
2023-10-29 12:45:21 -04:00
parent c5c942b418
commit 158b25acd1
8 changed files with 501 additions and 108 deletions

View File

@@ -160,9 +160,13 @@ func (c *Client) Stream(input string) error {
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.CompletionsRequest{
Messages: c.History,
Model: c.Config.Model,
Stream: stream,
Messages: c.History,
Model: c.Config.Model,
Temperature: c.Config.Temperature,
TopP: c.Config.TopP,
FrequencyPenalty: c.Config.FrequencyPenalty,
PresencePenalty: c.Config.PresencePenalty,
Stream: stream,
}
return json.Marshal(body)

View File

@@ -22,15 +22,19 @@ import (
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
const (
defaultMaxTokens = 4096
defaultURL = "https://default.openai.com"
defaultName = "default-name"
defaultModel = "gpt-3.5-turbo"
defaultCompletionsPath = "/default/completions"
defaultModelsPath = "/default/models"
defaultThread = "default-thread"
defaultRole = "You are a great default-role"
envApiKey = "api-key"
defaultMaxTokens = 4096
defaultURL = "https://default.openai.com"
defaultName = "default-name"
defaultModel = "gpt-3.5-turbo"
defaultCompletionsPath = "/default/completions"
defaultModelsPath = "/default/models"
defaultThread = "default-thread"
defaultRole = "You are a great default-role"
defaultTemperature = 1.1
defaultTopP = 2.2
defaultFrequencyPenalty = 3.3
defaultPresencePenalty = 4.4
envApiKey = "api-key"
)
var (
@@ -140,7 +144,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, false)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
respBytes, err := tt.setupPostReturn()
@@ -192,17 +196,26 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(answer))
}
it("uses the model specified by the configuration instead of the default model", func() {
const model = "overwritten"
it("uses the values specified by the configuration instead of the default values", func() {
const (
model = "overwritten"
temperature = 100.1
topP = 200.2
frequencyPenalty = 300.3
presencePenalty = 400.4
)
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: model,
Model: model,
Temperature: temperature,
TopP: topP,
FrequencyPenalty: frequencyPenalty,
PresencePenalty: presencePenalty,
})
body, err = createBody(messages, model, false)
body, err = createBodyWithConfig(messages, false, model, temperature, topP, frequencyPenalty, presencePenalty)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body, false)
})
@@ -225,7 +238,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
body, err = createBody(messages, subject.Config.Model, false)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body, false)
@@ -244,7 +257,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, false)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body, true)
@@ -289,7 +302,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
// messages get truncated. Index 1+2 are cut out
messages = append(messages[:1], messages[3:]...)
body, err = createBody(messages, subject.Config.Model, false)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body, false)
@@ -308,7 +321,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, true)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
errorMsg := "error message"
@@ -323,7 +336,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, true)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).Return([]byte(answer), nil)
@@ -344,7 +357,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, true)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
@@ -368,7 +381,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
messages = createMessages(history, query)
body, err = createBody(messages, subject.Config.Model, true)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body)
@@ -443,11 +456,29 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
}
func createBody(messages []types.Message, model string, stream bool) ([]byte, error) {
func createBody(messages []types.Message, stream bool) ([]byte, error) {
req := types.CompletionsRequest{
Model: model,
Messages: messages,
Stream: stream,
Model: defaultModel,
Messages: messages,
Stream: stream,
Temperature: defaultTemperature,
TopP: defaultTopP,
FrequencyPenalty: defaultFrequencyPenalty,
PresencePenalty: defaultPresencePenalty,
}
return json.Marshal(req)
}
func createBodyWithConfig(messages []types.Message, stream bool, model string, temperature float64, topP float64, frequencyPenalty float64, presencePenalty float64) ([]byte, error) {
req := types.CompletionsRequest{
Model: model,
Messages: messages,
Stream: stream,
Temperature: temperature,
TopP: topP,
FrequencyPenalty: frequencyPenalty,
PresencePenalty: presencePenalty,
}
return json.Marshal(req)
@@ -481,14 +512,18 @@ type clientFactory struct {
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
Name: defaultName,
Model: defaultModel,
MaxTokens: defaultMaxTokens,
URL: defaultURL,
CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath,
Role: defaultRole,
Thread: defaultThread,
Name: defaultName,
Model: defaultModel,
MaxTokens: defaultMaxTokens,
URL: defaultURL,
CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath,
Role: defaultRole,
Thread: defaultThread,
Temperature: defaultTemperature,
PresencePenalty: defaultPresencePenalty,
TopP: defaultTopP,
FrequencyPenalty: defaultFrequencyPenalty,
}).Times(1)
return &clientFactory{