mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add additional configuration options
- This includes temperature, top_p, frequency_penalty and presence_penalty
This commit is contained in:
@@ -160,9 +160,13 @@ func (c *Client) Stream(input string) error {
|
||||
|
||||
func (c *Client) createBody(stream bool) ([]byte, error) {
|
||||
body := types.CompletionsRequest{
|
||||
Messages: c.History,
|
||||
Model: c.Config.Model,
|
||||
Stream: stream,
|
||||
Messages: c.History,
|
||||
Model: c.Config.Model,
|
||||
Temperature: c.Config.Temperature,
|
||||
TopP: c.Config.TopP,
|
||||
FrequencyPenalty: c.Config.FrequencyPenalty,
|
||||
PresencePenalty: c.Config.PresencePenalty,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
return json.Marshal(body)
|
||||
|
||||
@@ -22,15 +22,19 @@ import (
|
||||
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
|
||||
|
||||
const (
|
||||
defaultMaxTokens = 4096
|
||||
defaultURL = "https://default.openai.com"
|
||||
defaultName = "default-name"
|
||||
defaultModel = "gpt-3.5-turbo"
|
||||
defaultCompletionsPath = "/default/completions"
|
||||
defaultModelsPath = "/default/models"
|
||||
defaultThread = "default-thread"
|
||||
defaultRole = "You are a great default-role"
|
||||
envApiKey = "api-key"
|
||||
defaultMaxTokens = 4096
|
||||
defaultURL = "https://default.openai.com"
|
||||
defaultName = "default-name"
|
||||
defaultModel = "gpt-3.5-turbo"
|
||||
defaultCompletionsPath = "/default/completions"
|
||||
defaultModelsPath = "/default/models"
|
||||
defaultThread = "default-thread"
|
||||
defaultRole = "You are a great default-role"
|
||||
defaultTemperature = 1.1
|
||||
defaultTopP = 2.2
|
||||
defaultFrequencyPenalty = 3.3
|
||||
defaultPresencePenalty = 4.4
|
||||
envApiKey = "api-key"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -140,7 +144,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
body, err = createBody(messages, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
respBytes, err := tt.setupPostReturn()
|
||||
@@ -192,17 +196,26 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal(answer))
|
||||
}
|
||||
|
||||
it("uses the model specified by the configuration instead of the default model", func() {
|
||||
const model = "overwritten"
|
||||
it("uses the values specified by the configuration instead of the default values", func() {
|
||||
const (
|
||||
model = "overwritten"
|
||||
temperature = 100.1
|
||||
topP = 200.2
|
||||
frequencyPenalty = 300.3
|
||||
presencePenalty = 400.4
|
||||
)
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithConfig(types.Config{
|
||||
Model: model,
|
||||
Model: model,
|
||||
Temperature: temperature,
|
||||
TopP: topP,
|
||||
FrequencyPenalty: frequencyPenalty,
|
||||
PresencePenalty: presencePenalty,
|
||||
})
|
||||
|
||||
body, err = createBody(messages, model, false)
|
||||
body, err = createBodyWithConfig(messages, false, model, temperature, topP, frequencyPenalty, presencePenalty)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
testValidHTTPResponse(subject, nil, body, false)
|
||||
})
|
||||
@@ -225,7 +238,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
factory.withHistory(history)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
body, err = createBody(messages, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body, false)
|
||||
@@ -244,7 +257,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
body, err = createBody(messages, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, nil, body, true)
|
||||
@@ -289,7 +302,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
// messages get truncated. Index 1+2 are cut out
|
||||
messages = append(messages[:1], messages[3:]...)
|
||||
|
||||
body, err = createBody(messages, subject.Config.Model, false)
|
||||
body, err = createBody(messages, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body, false)
|
||||
@@ -308,7 +321,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
body, err = createBody(messages, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
errorMsg := "error message"
|
||||
@@ -323,7 +336,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
body, err = createBody(messages, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).Return([]byte(answer), nil)
|
||||
@@ -344,7 +357,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
body, err = createBody(messages, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
@@ -368,7 +381,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(history, query)
|
||||
body, err = createBody(messages, subject.Config.Model, true)
|
||||
body, err = createBody(messages, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
@@ -443,11 +456,29 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
})
|
||||
}
|
||||
|
||||
func createBody(messages []types.Message, model string, stream bool) ([]byte, error) {
|
||||
func createBody(messages []types.Message, stream bool) ([]byte, error) {
|
||||
req := types.CompletionsRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: stream,
|
||||
Model: defaultModel,
|
||||
Messages: messages,
|
||||
Stream: stream,
|
||||
Temperature: defaultTemperature,
|
||||
TopP: defaultTopP,
|
||||
FrequencyPenalty: defaultFrequencyPenalty,
|
||||
PresencePenalty: defaultPresencePenalty,
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
func createBodyWithConfig(messages []types.Message, stream bool, model string, temperature float64, topP float64, frequencyPenalty float64, presencePenalty float64) ([]byte, error) {
|
||||
req := types.CompletionsRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: stream,
|
||||
Temperature: temperature,
|
||||
TopP: topP,
|
||||
FrequencyPenalty: frequencyPenalty,
|
||||
PresencePenalty: presencePenalty,
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
@@ -481,14 +512,18 @@ type clientFactory struct {
|
||||
|
||||
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
|
||||
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
|
||||
Name: defaultName,
|
||||
Model: defaultModel,
|
||||
MaxTokens: defaultMaxTokens,
|
||||
URL: defaultURL,
|
||||
CompletionsPath: defaultCompletionsPath,
|
||||
ModelsPath: defaultModelsPath,
|
||||
Role: defaultRole,
|
||||
Thread: defaultThread,
|
||||
Name: defaultName,
|
||||
Model: defaultModel,
|
||||
MaxTokens: defaultMaxTokens,
|
||||
URL: defaultURL,
|
||||
CompletionsPath: defaultCompletionsPath,
|
||||
ModelsPath: defaultModelsPath,
|
||||
Role: defaultRole,
|
||||
Thread: defaultThread,
|
||||
Temperature: defaultTemperature,
|
||||
PresencePenalty: defaultPresencePenalty,
|
||||
TopP: defaultTopP,
|
||||
FrequencyPenalty: defaultFrequencyPenalty,
|
||||
}).Times(1)
|
||||
|
||||
return &clientFactory{
|
||||
|
||||
Reference in New Issue
Block a user