Add the --set-model flag

This commit is contained in:
kardolus
2023-06-08 11:14:51 -04:00
parent b0708c5bfa
commit 1ea1b684fb
13 changed files with 621 additions and 134 deletions

View File

@@ -33,7 +33,7 @@ environment, demonstrating its practicality and effectiveness.
limits.
* **Custom context from local files**: Provide custom context through piping for GPT model reference during
conversation.
* **Custom chat models**: Use a custom chat model by specifying the model name with the `-m` or `--model` flag.
* **Custom chat models**: Use a custom chat model by specifying the model name with the `--set-model` flag. Ensure that the model exists in the OpenAI model list.
* **Model listing**: Get a list of available models by using the `-l` or `--list-models` flag.
* **Viper integration**: Robust configuration management.
@@ -131,7 +131,15 @@ Then, use the pipe feature to provide this context to ChatGPT:
cat context.txt | chatgpt "What kind of toy would Kya enjoy?"
```
6. To list all available models, use the -l or --list-models flag:
6. To set a specific model, use the `--set-model` flag followed by the model name:
```shell
chatgpt --set-model gpt-3.5-turbo-0301
```
Remember to check that the model exists in the OpenAI model list before setting it.
7. To list all available models, use the -l or --list-models flag:
```shell
chatgpt --list-models

View File

@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
@@ -26,20 +28,28 @@ const (
)
type Client struct {
History []types.Message
caller http.Caller
capacity int
model string
readWriter history.Store
History []types.Message
Model string
caller http.Caller
capacity int
historyStore history.HistoryStore
}
func New(caller http.Caller, rw history.Store) *Client {
return &Client{
caller: caller,
readWriter: rw,
capacity: MaxTokenSize,
model: DefaultGPTModel,
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
result := &Client{
caller: caller,
historyStore: hs,
capacity: MaxTokenSize,
}
// do not error out when the config cannot be read
result.Model, _ = configmanager.New(cs).ReadModel()
if result.Model == "" {
result.Model = DefaultGPTModel
}
return result
}
func (c *Client) WithCapacity(capacity int) *Client {
@@ -48,7 +58,7 @@ func (c *Client) WithCapacity(capacity int) *Client {
}
func (c *Client) WithModel(model string) *Client {
c.model = model
c.Model = model
return c
}
@@ -73,7 +83,7 @@ func (c *Client) ListModels() ([]string, error) {
for _, model := range response.Data {
if strings.HasPrefix(model.Id, gptPrefix) {
if model.Id != DefaultGPTModel {
if model.Id != c.Model {
result = append(result, fmt.Sprintf("- %s", model.Id))
continue
}
@@ -156,7 +166,7 @@ func (c *Client) Stream(input string) error {
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.CompletionsRequest{
Messages: c.History,
Model: c.model,
Model: c.Model,
Stream: stream,
}
@@ -168,7 +178,7 @@ func (c *Client) initHistory() {
return
}
c.History, _ = c.readWriter.Read()
c.History, _ = c.historyStore.Read()
if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
@@ -232,7 +242,7 @@ func (c *Client) updateHistory(response string) {
Role: AssistantRole,
Content: response,
})
_ = c.readWriter.Write(c.History)
_ = c.historyStore.Write(c.History)
}
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {

View File

@@ -16,13 +16,15 @@ import (
)
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
var (
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockStore *MockStore
subject *client.Client
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockHistoryStore *MockHistoryStore
mockConfigStore *MockConfigStore
factory *clientFactory
)
func TestUnitClient(t *testing.T) {
@@ -36,8 +38,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockCaller = NewMockCaller(mockCtrl)
mockStore = NewMockStore(mockCtrl)
subject = client.New(mockCaller, mockStore).WithCapacity(50)
mockHistoryStore = NewMockHistoryStore(mockCtrl)
mockConfigStore = NewMockConfigStore(mockCtrl)
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
})
it.After(func() {
@@ -51,63 +54,75 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err error
)
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
})
type TestCase struct {
description string
setupPostReturn func() ([]byte, error)
postError error
expectedError string
}
it("throws an error when the http callout fails", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
tests := []TestCase{
{
description: "throws an error when the http callout fails",
setupPostReturn: func() ([]byte, error) { return nil, nil },
postError: errors.New("error message"),
expectedError: "error message",
},
{
description: "throws an error when the response is empty",
setupPostReturn: func() ([]byte, error) { return nil, nil },
postError: nil,
expectedError: "empty response",
},
{
description: "throws an error when the response is a malformed json",
setupPostReturn: func() ([]byte, error) {
malformed := `{"invalid":"json"` // missing closing brace
return []byte(malformed), nil
},
postError: nil,
expectedError: "failed to decode response:",
},
{
description: "throws an error when the response is missing Choices",
setupPostReturn: func() ([]byte, error) {
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
Model: "model",
Choices: []types.Choice{},
}
errorMsg := "error message"
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg))
respBytes, err := json.Marshal(response)
return respBytes, err
},
postError: nil,
expectedError: "no responses returned",
},
}
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil)
for _, tt := range tests {
it(tt.description, func() {
factory.withoutHistory()
subject := factory.buildClientWithoutConfig()
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil)
respBytes, err := tt.setupPostReturn()
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
it("throws an error when the response is missing Choices", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
})
}
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
Model: "model",
Choices: []types.Choice{},
}
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("no responses returned"))
})
when("a valid http response is received", func() {
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
const answer = "content"
choice := types.Choice{
@@ -122,7 +137,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
ID: "id",
Object: "object",
Created: 0,
Model: client.DefaultGPTModel,
Model: subject.Model,
Choices: []types.Choice{choice},
}
@@ -134,7 +149,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err = json.Unmarshal(expectedBody, &request)
Expect(err).NotTo(HaveOccurred())
mockStore.EXPECT().Write(append(request.Messages, types.Message{
mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
@@ -144,8 +159,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(result).To(Equal(answer))
}
it("returns the expected result for an empty history", func() {
testValidHTTPResponse(nil, body)
it("uses the model specified by the WithModel method instead of the default model", func() {
const model = "overwritten"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithoutConfig().WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("uses the model specified by the configuration instead of the default model", func() {
const model = "overwritten"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: model,
})
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("when WithModel is used and a configuration is present, WithModel takes precedence", func() {
const model = "with-model"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: "config-model",
}).WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("returns the expected result for a non-empty history", func() {
history := []types.Message{
@@ -163,10 +212,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
},
}
messages = createMessages(history, query)
body, err = createBody(messages, false)
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
it("truncates the history as expected", func() {
history := []types.Message{
@@ -202,13 +254,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
messages = createMessages(history, query)
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
// messages get truncated. Index 1+2 are cut out
messages = append(messages[:1], messages[3:]...)
body, err = createBody(messages, false)
body, err = createBody(messages, subject.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
})
})
@@ -219,14 +274,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err error
)
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
})
it("throws an error when the http callout fails", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
factory.withoutHistory()
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
errorMsg := "error message"
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
@@ -238,13 +292,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
when("a valid http response is received", func() {
const answer = "answer"
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
mockStore.EXPECT().Write(append(messages, types.Message{
mockHistoryStore.EXPECT().Write(append(messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
@@ -254,7 +311,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}
it("returns the expected result for an empty history", func() {
testValidHTTPResponse(nil, body)
factory.withHistory(nil)
subject := factory.buildClientWithoutConfig()
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("returns the expected result for a non-empty history", func() {
history := []types.Message{
@@ -271,16 +335,21 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Content: "answer 1",
},
}
factory.withHistory(history)
subject := factory.buildClientWithoutConfig()
messages = createMessages(history, query)
body, err = createBody(messages, true)
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
testValidHTTPResponse(subject, history, body)
})
})
})
when("ListModels()", func() {
it("throws an error when the http callout fails", func() {
subject := factory.buildClientWithoutConfig()
errorMsg := "error message"
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
@@ -289,6 +358,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
subject := factory.buildClientWithoutConfig()
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
_, err := subject.ListModels()
@@ -296,6 +367,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
subject := factory.buildClientWithoutConfig()
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
@@ -304,6 +377,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
it("filters gpt models as expected", func() {
subject := factory.buildClientWithoutConfig()
response, err := utils.FileToBytes("models.json")
Expect(err).NotTo(HaveOccurred())
@@ -319,8 +394,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
when("ProvideContext()", func() {
it("updates the history with the provided context", func() {
subject := factory.buildClientWithoutConfig()
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
subject.ProvideContext(context)
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
@@ -336,9 +414,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
}
func createBody(messages []types.Message, stream bool) ([]byte, error) {
func createBody(messages []types.Message, model string, stream bool) ([]byte, error) {
req := types.CompletionsRequest{
Model: client.DefaultGPTModel,
Model: model,
Messages: messages,
Stream: stream,
}
@@ -365,3 +443,35 @@ func createMessages(history []types.Message, query string) []types.Message {
return messages
}
type clientFactory struct {
mockCaller *MockCaller
mockConfigStore *MockConfigStore
mockHistoryStore *MockHistoryStore
}
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
return &clientFactory{
mockCaller: mc,
mockConfigStore: mcs,
mockHistoryStore: mhs,
}
}
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
}
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
}
func (f *clientFactory) withoutHistory() {
f.mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
}
func (f *clientFactory) withHistory(history []types.Message) {
f.mockHistoryStore.EXPECT().Read().Return(history, nil).Times(1)
}

View File

@@ -0,0 +1,64 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore)
// Package client_test is a generated GoMock package.
package client_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
types "github.com/kardolus/chatgpt-cli/types"
)
// MockConfigStore is a mock of ConfigStore interface.
type MockConfigStore struct {
ctrl *gomock.Controller
recorder *MockConfigStoreMockRecorder
}
// MockConfigStoreMockRecorder is the mock recorder for MockConfigStore.
type MockConfigStoreMockRecorder struct {
mock *MockConfigStore
}
// NewMockConfigStore creates a new mock instance.
func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore {
mock := &MockConfigStore{ctrl: ctrl}
mock.recorder = &MockConfigStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// Read mocks base method.
func (m *MockConfigStore) Read() (types.Config, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].(types.Config)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
}
// Write mocks base method.
func (m *MockConfigStore) Write(arg0 types.Config) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0)
}

View File

@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: HistoryStore)
// Package client_test is a generated GoMock package.
package client_test
@@ -11,31 +11,31 @@ import (
types "github.com/kardolus/chatgpt-cli/types"
)
// MockStore is a mock of Store interface.
type MockStore struct {
// MockHistoryStore is a mock of HistoryStore interface.
type MockHistoryStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
recorder *MockHistoryStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
// MockHistoryStoreMockRecorder is the mock recorder for MockHistoryStore.
type MockHistoryStoreMockRecorder struct {
mock *MockHistoryStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
// NewMockHistoryStore creates a new mock instance.
func NewMockHistoryStore(ctrl *gomock.Controller) *MockHistoryStore {
mock := &MockHistoryStore{ctrl: ctrl}
mock.recorder = &MockHistoryStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockStore) Delete() error {
func (m *MockHistoryStore) Delete() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete")
ret0, _ := ret[0].(error)
@@ -43,13 +43,13 @@ func (m *MockStore) Delete() error {
}
// Delete indicates an expected call of Delete.
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Delete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockHistoryStore)(nil).Delete))
}
// Read mocks base method.
func (m *MockStore) Read() ([]types.Message, error) {
func (m *MockHistoryStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].([]types.Message)
@@ -58,13 +58,13 @@ func (m *MockStore) Read() ([]types.Message, error) {
}
// Read indicates an expected call of Read.
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockHistoryStore)(nil).Read))
}
// Write mocks base method.
func (m *MockStore) Write(arg0 []types.Message) error {
func (m *MockHistoryStore) Write(arg0 []types.Message) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
@@ -72,7 +72,7 @@ func (m *MockStore) Write(arg0 []types.Message) error {
}
// Write indicates an expected call of Write.
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
func (mr *MockHistoryStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockHistoryStore)(nil).Write), arg0)
}

View File

@@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/spf13/cobra"
@@ -43,6 +45,7 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information")
rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models")
rootCmd.PersistentFlags().StringVarP(&modelName, "model", "m", "", "Use a custom GPT model by specifying the model name")
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
viper.AutomaticEnv()
@@ -57,7 +60,7 @@ func run(cmd *cobra.Command, args []string) error {
if secret == "" {
return errors.New("missing environment variable: " + secretEnv)
}
client := client.New(http.New().WithSecret(secret), history.New())
client := client.New(http.New().WithSecret(secret), config.New(), history.New())
if modelName != "" {
client = client.WithModel(modelName)
@@ -78,6 +81,14 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}
if cmd.Flag("set-model").Changed {
if err := configmanager.New(config.New()).WriteModel(modelName); err != nil {
return err
}
fmt.Println("Model successfully updated to", modelName)
return nil
}
if clearHistory {
historyHandler := history.New()
err := historyHandler.Delete()

69
config/config.go Normal file
View File

@@ -0,0 +1,69 @@
package config
import (
"github.com/kardolus/chatgpt-cli/types"
"gopkg.in/yaml.v3"
"os"
"path/filepath"
)
type ConfigStore interface {
Read() (types.Config, error)
Write(types.Config) error
}
// Ensure FileIO implements ConfigStore interface
var _ ConfigStore = &FileIO{}
type FileIO struct {
configFilePath string
}
func New() *FileIO {
path, _ := getPath()
return &FileIO{
configFilePath: path,
}
}
func (f *FileIO) WithFilePath(configFilePath string) *FileIO {
f.configFilePath = configFilePath
return f
}
func (f *FileIO) Read() (types.Config, error) {
return parseFile(f.configFilePath)
}
func (f *FileIO) Write(config types.Config) error {
data, err := yaml.Marshal(config)
if err != nil {
return err
}
return os.WriteFile(f.configFilePath, data, 0644)
}
func getPath() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(homeDir, ".chatgpt-cli", "config.yaml"), nil
}
func parseFile(fileName string) (types.Config, error) {
var result types.Config
buf, err := os.ReadFile(fileName)
if err != nil {
return types.Config{}, err
}
if err := yaml.Unmarshal(buf, &result); err != nil {
return types.Config{}, err
}
return result, nil
}

View File

@@ -0,0 +1,27 @@
package configmanager
import (
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/types"
)
type ConfigManager struct {
configStore config.ConfigStore
}
func New(cs config.ConfigStore) *ConfigManager {
return &ConfigManager{configStore: cs}
}
func (c *ConfigManager) ReadModel() (string, error) {
conf, err := c.configStore.Read()
if err != nil {
return "", err
}
return conf.Model, nil
}
func (c *ConfigManager) WriteModel(model string) error {
return c.configStore.Write(types.Config{Model: model})
}

View File

@@ -0,0 +1,62 @@
package configmanager_test
import (
"errors"
"github.com/golang/mock/gomock"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/types"
. "github.com/onsi/gomega"
"github.com/sclevine/spec"
"github.com/sclevine/spec/report"
"testing"
)
//go:generate mockgen -destination=configmocks_test.go -package=configmanager_test github.com/kardolus/chatgpt-cli/config ConfigStore
func TestUnitConfigManager(t *testing.T) {
spec.Run(t, "Testing the Config Manager", testConfig, spec.Report(report.Terminal{}))
}
func testConfig(t *testing.T, when spec.G, it spec.S) {
var (
mockCtrl *gomock.Controller
mockConfigStore *MockConfigStore
subject *configmanager.ConfigManager
)
it.Before(func() {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockConfigStore = NewMockConfigStore(mockCtrl)
subject = configmanager.New(mockConfigStore)
})
it.After(func() {
mockCtrl.Finish()
})
when("ReadModel()", func() {
it("throws an error when the config file does not exist", func() {
expectedErrorMsg := "file not found"
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New(expectedErrorMsg)).Times(1)
_, err := subject.ReadModel()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(expectedErrorMsg))
})
it("parses a config file as expected", func() {
modelName := "the-model"
mockConfigStore.EXPECT().Read().Return(types.Config{Model: modelName}, nil).Times(1)
result, err := subject.ReadModel()
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(modelName))
})
})
when("WriteModel()", func() {
it("writes the expected config file", func() {
modelName := "the-model"
mockConfigStore.EXPECT().Write(types.Config{Model: modelName}).Times(1)
Expect(subject.WriteModel(modelName)).To(Succeed())
})
})
}

View File

@@ -0,0 +1,64 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore)
// Package configmanager_test is a generated GoMock package.
package configmanager_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
types "github.com/kardolus/chatgpt-cli/types"
)
// MockConfigStore is a mock of ConfigStore interface.
type MockConfigStore struct {
ctrl *gomock.Controller
recorder *MockConfigStoreMockRecorder
}
// MockConfigStoreMockRecorder is the mock recorder for MockConfigStore.
type MockConfigStoreMockRecorder struct {
mock *MockConfigStore
}
// NewMockConfigStore creates a new mock instance.
func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore {
mock := &MockConfigStore{ctrl: ctrl}
mock.recorder = &MockConfigStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// Read mocks base method.
func (m *MockConfigStore) Read() (types.Config, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].(types.Config)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
}
// Write mocks base method.
func (m *MockConfigStore) Write(arg0 types.Config) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0)
}

View File

@@ -3,19 +3,18 @@ package history
import (
"encoding/json"
"github.com/kardolus/chatgpt-cli/types"
"io/ioutil"
"os"
"path/filepath"
)
type Store interface {
type HistoryStore interface {
Delete() error
Read() ([]types.Message, error)
Write([]types.Message) error
}
// Ensure RestCaller implements Caller interface
var _ Store = &FileIO{}
// Ensure FileIO implements the HistoryStore interface
var _ HistoryStore = &FileIO{}
type FileIO struct {
historyFilePath string
@@ -28,7 +27,7 @@ func New() *FileIO {
}
}
func (f *FileIO) WithHistory(historyFilePath string) *FileIO {
func (f *FileIO) WithFilePath(historyFilePath string) *FileIO {
f.historyFilePath = historyFilePath
return f
}
@@ -51,7 +50,7 @@ func (f *FileIO) Write(messages []types.Message) error {
return err
}
return ioutil.WriteFile(f.historyFilePath, data, 0644)
return os.WriteFile(f.historyFilePath, data, 0644)
}
func getPath() (string, error) {
@@ -66,7 +65,7 @@ func getPath() (string, error) {
func parseFile(fileName string) ([]types.Message, error) {
var result []types.Message
buf, err := ioutil.ReadFile(fileName)
buf, err := os.ReadFile(fileName)
if err != nil {
return nil, err
}

View File

@@ -1,11 +1,11 @@
package integration_test
import (
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/types"
"github.com/sclevine/spec"
"github.com/sclevine/spec/report"
"io/ioutil"
"os"
"testing"
@@ -21,7 +21,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
})
when("Read, Write and Delete", func() {
when("Read, Write and Delete History", func() {
var (
tmpDir string
tmpFile *os.File
@@ -31,15 +31,15 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
)
it.Before(func() {
tmpDir, err = ioutil.TempDir("", "chatgpt-cli-test")
tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test")
Expect(err).NotTo(HaveOccurred())
tmpFile, err = ioutil.TempFile(tmpDir, "history.json")
tmpFile, err = os.CreateTemp(tmpDir, "history.json")
Expect(err).NotTo(HaveOccurred())
tmpFile.Close()
Expect(tmpFile.Close()).To(Succeed())
fileIO = history.New().WithHistory(tmpFile.Name())
fileIO = history.New().WithFilePath(tmpFile.Name())
messages = []types.Message{
{
@@ -54,7 +54,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
})
it.After(func() {
os.RemoveAll(tmpDir)
Expect(os.RemoveAll(tmpDir)).To(Succeed())
})
it("writes the messages to the file", func() {
@@ -79,4 +79,62 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(os.IsNotExist(err)).To(BeTrue())
})
})
when("Read, Write Config", func() {
var (
tmpDir string
tmpFile *os.File
configIO *config.FileIO
testConfig types.Config
err error
)
it.Before(func() {
tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test")
Expect(err).NotTo(HaveOccurred())
tmpFile, err = os.CreateTemp(tmpDir, "config.yaml")
Expect(err).NotTo(HaveOccurred())
Expect(tmpFile.Close()).To(Succeed())
configIO = config.New().WithFilePath(tmpFile.Name())
testConfig = types.Config{
Model: "test-model",
}
})
it.After(func() {
Expect(os.RemoveAll(tmpDir)).To(Succeed())
})
it("writes the config to the file", func() {
err = configIO.Write(testConfig)
Expect(err).NotTo(HaveOccurred())
})
it("reads the config from the file", func() {
err = configIO.Write(testConfig) // need to write before reading
Expect(err).NotTo(HaveOccurred())
readConfig, err := configIO.Read()
Expect(err).NotTo(HaveOccurred())
Expect(readConfig).To(Equal(testConfig))
})
// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration.
it("overwrites the existing config", func() {
newConfig := types.Config{
Model: "new-model",
}
err = configIO.Write(newConfig)
Expect(err).NotTo(HaveOccurred())
readConfig, err := configIO.Read()
Expect(err).NotTo(HaveOccurred())
Expect(readConfig).To(Equal(newConfig))
})
})
}

5
types/config.go Normal file
View File

@@ -0,0 +1,5 @@
package types
type Config struct {
Model string `yaml:"model"`
}