mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add the --set-model flag
This commit is contained in:
12
README.md
12
README.md
@@ -33,7 +33,7 @@ environment, demonstrating its practicality and effectiveness.
|
||||
limits.
|
||||
* **Custom context from local files**: Provide custom context through piping for GPT model reference during
|
||||
conversation.
|
||||
* **Custom chat models**: Use a custom chat model by specifying the model name with the `-m` or `--model` flag.
|
||||
* **Custom chat models**: Use a custom chat model by specifying the model name with the `--set-model` flag. Ensure that the model exists in the OpenAI model list.
|
||||
* **Model listing**: Get a list of available models by using the `-l` or `--list-models` flag.
|
||||
* **Viper integration**: Robust configuration management.
|
||||
|
||||
@@ -131,7 +131,15 @@ Then, use the pipe feature to provide this context to ChatGPT:
|
||||
cat context.txt | chatgpt "What kind of toy would Kya enjoy?"
|
||||
```
|
||||
|
||||
6. To list all available models, use the -l or --list-models flag:
|
||||
6. To set a specific model, use the `--set-model` flag followed by the model name:
|
||||
|
||||
```shell
|
||||
chatgpt --set-model gpt-3.5-turbo-0301
|
||||
```
|
||||
|
||||
Remember to check that the model exists in the OpenAI model list before setting it.
|
||||
|
||||
7. To list all available models, use the -l or --list-models flag:
|
||||
|
||||
```shell
|
||||
chatgpt --list-models
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-cli/config"
|
||||
"github.com/kardolus/chatgpt-cli/configmanager"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/http"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
@@ -26,20 +28,28 @@ const (
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
History []types.Message
|
||||
caller http.Caller
|
||||
capacity int
|
||||
model string
|
||||
readWriter history.Store
|
||||
History []types.Message
|
||||
Model string
|
||||
caller http.Caller
|
||||
capacity int
|
||||
historyStore history.HistoryStore
|
||||
}
|
||||
|
||||
func New(caller http.Caller, rw history.Store) *Client {
|
||||
return &Client{
|
||||
caller: caller,
|
||||
readWriter: rw,
|
||||
capacity: MaxTokenSize,
|
||||
model: DefaultGPTModel,
|
||||
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
|
||||
result := &Client{
|
||||
caller: caller,
|
||||
historyStore: hs,
|
||||
capacity: MaxTokenSize,
|
||||
}
|
||||
|
||||
// do not error out when the config cannot be read
|
||||
result.Model, _ = configmanager.New(cs).ReadModel()
|
||||
|
||||
if result.Model == "" {
|
||||
result.Model = DefaultGPTModel
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Client) WithCapacity(capacity int) *Client {
|
||||
@@ -48,7 +58,7 @@ func (c *Client) WithCapacity(capacity int) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) WithModel(model string) *Client {
|
||||
c.model = model
|
||||
c.Model = model
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -73,7 +83,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
|
||||
for _, model := range response.Data {
|
||||
if strings.HasPrefix(model.Id, gptPrefix) {
|
||||
if model.Id != DefaultGPTModel {
|
||||
if model.Id != c.Model {
|
||||
result = append(result, fmt.Sprintf("- %s", model.Id))
|
||||
continue
|
||||
}
|
||||
@@ -156,7 +166,7 @@ func (c *Client) Stream(input string) error {
|
||||
func (c *Client) createBody(stream bool) ([]byte, error) {
|
||||
body := types.CompletionsRequest{
|
||||
Messages: c.History,
|
||||
Model: c.model,
|
||||
Model: c.Model,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
@@ -168,7 +178,7 @@ func (c *Client) initHistory() {
|
||||
return
|
||||
}
|
||||
|
||||
c.History, _ = c.readWriter.Read()
|
||||
c.History, _ = c.historyStore.Read()
|
||||
if len(c.History) == 0 {
|
||||
c.History = []types.Message{{
|
||||
Role: SystemRole,
|
||||
@@ -232,7 +242,7 @@ func (c *Client) updateHistory(response string) {
|
||||
Role: AssistantRole,
|
||||
Content: response,
|
||||
})
|
||||
_ = c.readWriter.Write(c.History)
|
||||
_ = c.historyStore.Write(c.History)
|
||||
}
|
||||
|
||||
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {
|
||||
|
||||
@@ -16,13 +16,15 @@ import (
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
|
||||
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
|
||||
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore
|
||||
//go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore
|
||||
|
||||
var (
|
||||
mockCtrl *gomock.Controller
|
||||
mockCaller *MockCaller
|
||||
mockStore *MockStore
|
||||
subject *client.Client
|
||||
mockCtrl *gomock.Controller
|
||||
mockCaller *MockCaller
|
||||
mockHistoryStore *MockHistoryStore
|
||||
mockConfigStore *MockConfigStore
|
||||
factory *clientFactory
|
||||
)
|
||||
|
||||
func TestUnitClient(t *testing.T) {
|
||||
@@ -36,8 +38,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
RegisterTestingT(t)
|
||||
mockCtrl = gomock.NewController(t)
|
||||
mockCaller = NewMockCaller(mockCtrl)
|
||||
mockStore = NewMockStore(mockCtrl)
|
||||
subject = client.New(mockCaller, mockStore).WithCapacity(50)
|
||||
mockHistoryStore = NewMockHistoryStore(mockCtrl)
|
||||
mockConfigStore = NewMockConfigStore(mockCtrl)
|
||||
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
@@ -51,63 +54,75 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
err error
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
type TestCase struct {
|
||||
description string
|
||||
setupPostReturn func() ([]byte, error)
|
||||
postError error
|
||||
expectedError string
|
||||
}
|
||||
|
||||
it("throws an error when the http callout fails", func() {
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
tests := []TestCase{
|
||||
{
|
||||
description: "throws an error when the http callout fails",
|
||||
setupPostReturn: func() ([]byte, error) { return nil, nil },
|
||||
postError: errors.New("error message"),
|
||||
expectedError: "error message",
|
||||
},
|
||||
{
|
||||
description: "throws an error when the response is empty",
|
||||
setupPostReturn: func() ([]byte, error) { return nil, nil },
|
||||
postError: nil,
|
||||
expectedError: "empty response",
|
||||
},
|
||||
{
|
||||
description: "throws an error when the response is a malformed json",
|
||||
setupPostReturn: func() ([]byte, error) {
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
return []byte(malformed), nil
|
||||
},
|
||||
postError: nil,
|
||||
expectedError: "failed to decode response:",
|
||||
},
|
||||
{
|
||||
description: "throws an error when the response is missing Choices",
|
||||
setupPostReturn: func() ([]byte, error) {
|
||||
response := &types.CompletionsResponse{
|
||||
ID: "id",
|
||||
Object: "object",
|
||||
Created: 0,
|
||||
Model: "model",
|
||||
Choices: []types.Choice{},
|
||||
}
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg))
|
||||
respBytes, err := json.Marshal(response)
|
||||
return respBytes, err
|
||||
},
|
||||
postError: nil,
|
||||
expectedError: "no responses returned",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(errorMsg))
|
||||
})
|
||||
it("throws an error when the response is empty", func() {
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil)
|
||||
for _, tt := range tests {
|
||||
it(tt.description, func() {
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("empty response"))
|
||||
})
|
||||
it("throws an error when the response is a malformed json", func() {
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil)
|
||||
respBytes, err := tt.setupPostReturn()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError)
|
||||
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
|
||||
})
|
||||
it("throws an error when the response is missing Choices", func() {
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
|
||||
})
|
||||
}
|
||||
|
||||
response := &types.CompletionsResponse{
|
||||
ID: "id",
|
||||
Object: "object",
|
||||
Created: 0,
|
||||
Model: "model",
|
||||
Choices: []types.Choice{},
|
||||
}
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("no responses returned"))
|
||||
})
|
||||
when("a valid http response is received", func() {
|
||||
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
|
||||
mockStore.EXPECT().Read().Return(history, nil).Times(1)
|
||||
|
||||
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
|
||||
const answer = "content"
|
||||
|
||||
choice := types.Choice{
|
||||
@@ -122,7 +137,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
ID: "id",
|
||||
Object: "object",
|
||||
Created: 0,
|
||||
Model: client.DefaultGPTModel,
|
||||
Model: subject.Model,
|
||||
Choices: []types.Choice{choice},
|
||||
}
|
||||
|
||||
@@ -134,7 +149,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
err = json.Unmarshal(expectedBody, &request)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockStore.EXPECT().Write(append(request.Messages, types.Message{
|
||||
mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{
|
||||
Role: client.AssistantRole,
|
||||
Content: answer,
|
||||
}))
|
||||
@@ -144,8 +159,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(result).To(Equal(answer))
|
||||
}
|
||||
|
||||
it("returns the expected result for an empty history", func() {
|
||||
testValidHTTPResponse(nil, body)
|
||||
it("uses the model specified by the WithModel method instead of the default model", func() {
|
||||
const model = "overwritten"
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithoutConfig().WithModel(model)
|
||||
|
||||
body, err = createBody(messages, model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
})
|
||||
it("uses the model specified by the configuration instead of the default model", func() {
|
||||
const model = "overwritten"
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithConfig(types.Config{
|
||||
Model: model,
|
||||
})
|
||||
|
||||
body, err = createBody(messages, model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
})
|
||||
it("when WithModel is used and a configuration is present, WithModel takes precedence", func() {
|
||||
const model = "with-model"
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithConfig(types.Config{
|
||||
Model: "config-model",
|
||||
}).WithModel(model)
|
||||
|
||||
body, err = createBody(messages, model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
})
|
||||
it("returns the expected result for a non-empty history", func() {
|
||||
history := []types.Message{
|
||||
@@ -163,10 +212,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
},
|
||||
}
|
||||
messages = createMessages(history, query)
|
||||
body, err = createBody(messages, false)
|
||||
factory.withHistory(history)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(history, body)
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
})
|
||||
it("truncates the history as expected", func() {
|
||||
history := []types.Message{
|
||||
@@ -202,13 +254,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
messages = createMessages(history, query)
|
||||
|
||||
factory.withHistory(history)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
// messages get truncated. Index 1+2 are cut out
|
||||
messages = append(messages[:1], messages[3:]...)
|
||||
|
||||
body, err = createBody(messages, false)
|
||||
body, err = createBody(messages, subject.Model, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(history, body)
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -219,14 +274,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
err error
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
it("throws an error when the http callout fails", func() {
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
factory.withoutHistory()
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
|
||||
@@ -238,13 +292,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
when("a valid http response is received", func() {
|
||||
const answer = "answer"
|
||||
|
||||
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
|
||||
mockStore.EXPECT().Read().Return(history, nil).Times(1)
|
||||
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) {
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
|
||||
|
||||
messages = createMessages(history, query)
|
||||
|
||||
mockStore.EXPECT().Write(append(messages, types.Message{
|
||||
mockHistoryStore.EXPECT().Write(append(messages, types.Message{
|
||||
Role: client.AssistantRole,
|
||||
Content: answer,
|
||||
}))
|
||||
@@ -254,7 +311,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
}
|
||||
|
||||
it("returns the expected result for an empty history", func() {
|
||||
testValidHTTPResponse(nil, body)
|
||||
factory.withHistory(nil)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(subject, nil, body)
|
||||
})
|
||||
it("returns the expected result for a non-empty history", func() {
|
||||
history := []types.Message{
|
||||
@@ -271,16 +335,21 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Content: "answer 1",
|
||||
},
|
||||
}
|
||||
factory.withHistory(history)
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
messages = createMessages(history, query)
|
||||
body, err = createBody(messages, true)
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
testValidHTTPResponse(history, body)
|
||||
testValidHTTPResponse(subject, history, body)
|
||||
})
|
||||
})
|
||||
})
|
||||
when("ListModels()", func() {
|
||||
it("throws an error when the http callout fails", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
|
||||
|
||||
@@ -289,6 +358,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err.Error()).To(Equal(errorMsg))
|
||||
})
|
||||
it("throws an error when the response is empty", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
|
||||
|
||||
_, err := subject.ListModels()
|
||||
@@ -296,6 +367,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err.Error()).To(Equal("empty response"))
|
||||
})
|
||||
it("throws an error when the response is a malformed json", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
|
||||
|
||||
@@ -304,6 +377,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
|
||||
})
|
||||
it("filters gpt models as expected", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
response, err := utils.FileToBytes("models.json")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
@@ -319,8 +394,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
})
|
||||
when("ProvideContext()", func() {
|
||||
it("updates the history with the provided context", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
|
||||
subject.ProvideContext(context)
|
||||
|
||||
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
|
||||
@@ -336,9 +414,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
})
|
||||
}
|
||||
|
||||
func createBody(messages []types.Message, stream bool) ([]byte, error) {
|
||||
func createBody(messages []types.Message, model string, stream bool) ([]byte, error) {
|
||||
req := types.CompletionsRequest{
|
||||
Model: client.DefaultGPTModel,
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: stream,
|
||||
}
|
||||
@@ -365,3 +443,35 @@ func createMessages(history []types.Message, query string) []types.Message {
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
type clientFactory struct {
|
||||
mockCaller *MockCaller
|
||||
mockConfigStore *MockConfigStore
|
||||
mockHistoryStore *MockHistoryStore
|
||||
}
|
||||
|
||||
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
|
||||
return &clientFactory{
|
||||
mockCaller: mc,
|
||||
mockConfigStore: mcs,
|
||||
mockHistoryStore: mhs,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
|
||||
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
|
||||
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
|
||||
}
|
||||
|
||||
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
|
||||
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
|
||||
return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50)
|
||||
}
|
||||
|
||||
func (f *clientFactory) withoutHistory() {
|
||||
f.mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
}
|
||||
|
||||
func (f *clientFactory) withHistory(history []types.Message) {
|
||||
f.mockHistoryStore.EXPECT().Read().Return(history, nil).Times(1)
|
||||
}
|
||||
|
||||
64
client/configmocks_test.go
Normal file
64
client/configmocks_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore)
|
||||
|
||||
// Package client_test is a generated GoMock package.
|
||||
package client_test
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
types "github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
// MockConfigStore is a mock of ConfigStore interface.
|
||||
type MockConfigStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfigStoreMockRecorder
|
||||
}
|
||||
|
||||
// MockConfigStoreMockRecorder is the mock recorder for MockConfigStore.
|
||||
type MockConfigStoreMockRecorder struct {
|
||||
mock *MockConfigStore
|
||||
}
|
||||
|
||||
// NewMockConfigStore creates a new mock instance.
|
||||
func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore {
|
||||
mock := &MockConfigStore{ctrl: ctrl}
|
||||
mock.recorder = &MockConfigStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockConfigStore) Read() (types.Config, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read")
|
||||
ret0, _ := ret[0].(types.Config)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockConfigStore) Write(arg0 types.Config) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write.
|
||||
func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
|
||||
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: HistoryStore)
|
||||
|
||||
// Package client_test is a generated GoMock package.
|
||||
package client_test
|
||||
@@ -11,31 +11,31 @@ import (
|
||||
types "github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
// MockStore is a mock of Store interface.
|
||||
type MockStore struct {
|
||||
// MockHistoryStore is a mock of HistoryStore interface.
|
||||
type MockHistoryStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStoreMockRecorder
|
||||
recorder *MockHistoryStoreMockRecorder
|
||||
}
|
||||
|
||||
// MockStoreMockRecorder is the mock recorder for MockStore.
|
||||
type MockStoreMockRecorder struct {
|
||||
mock *MockStore
|
||||
// MockHistoryStoreMockRecorder is the mock recorder for MockHistoryStore.
|
||||
type MockHistoryStoreMockRecorder struct {
|
||||
mock *MockHistoryStore
|
||||
}
|
||||
|
||||
// NewMockStore creates a new mock instance.
|
||||
func NewMockStore(ctrl *gomock.Controller) *MockStore {
|
||||
mock := &MockStore{ctrl: ctrl}
|
||||
mock.recorder = &MockStoreMockRecorder{mock}
|
||||
// NewMockHistoryStore creates a new mock instance.
|
||||
func NewMockHistoryStore(ctrl *gomock.Controller) *MockHistoryStore {
|
||||
mock := &MockHistoryStore{ctrl: ctrl}
|
||||
mock.recorder = &MockHistoryStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockStore) Delete() error {
|
||||
func (m *MockHistoryStore) Delete() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete")
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -43,13 +43,13 @@ func (m *MockStore) Delete() error {
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
|
||||
func (mr *MockHistoryStoreMockRecorder) Delete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockHistoryStore)(nil).Delete))
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockStore) Read() ([]types.Message, error) {
|
||||
func (m *MockHistoryStore) Read() ([]types.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read")
|
||||
ret0, _ := ret[0].([]types.Message)
|
||||
@@ -58,13 +58,13 @@ func (m *MockStore) Read() ([]types.Message, error) {
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
|
||||
func (mr *MockHistoryStoreMockRecorder) Read() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockHistoryStore)(nil).Read))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockStore) Write(arg0 []types.Message) error {
|
||||
func (m *MockHistoryStore) Write(arg0 []types.Message) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -72,7 +72,7 @@ func (m *MockStore) Write(arg0 []types.Message) error {
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write.
|
||||
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockHistoryStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockHistoryStore)(nil).Write), arg0)
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-cli/client"
|
||||
"github.com/kardolus/chatgpt-cli/config"
|
||||
"github.com/kardolus/chatgpt-cli/configmanager"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/http"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -43,6 +45,7 @@ func main() {
|
||||
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information")
|
||||
rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models")
|
||||
rootCmd.PersistentFlags().StringVarP(&modelName, "model", "m", "", "Use a custom GPT model by specifying the model name")
|
||||
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
|
||||
|
||||
viper.AutomaticEnv()
|
||||
|
||||
@@ -57,7 +60,7 @@ func run(cmd *cobra.Command, args []string) error {
|
||||
if secret == "" {
|
||||
return errors.New("missing environment variable: " + secretEnv)
|
||||
}
|
||||
client := client.New(http.New().WithSecret(secret), history.New())
|
||||
client := client.New(http.New().WithSecret(secret), config.New(), history.New())
|
||||
|
||||
if modelName != "" {
|
||||
client = client.WithModel(modelName)
|
||||
@@ -78,6 +81,14 @@ func run(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cmd.Flag("set-model").Changed {
|
||||
if err := configmanager.New(config.New()).WriteModel(modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Model successfully updated to", modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if clearHistory {
|
||||
historyHandler := history.New()
|
||||
err := historyHandler.Delete()
|
||||
|
||||
69
config/config.go
Normal file
69
config/config.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
"gopkg.in/yaml.v3"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type ConfigStore interface {
|
||||
Read() (types.Config, error)
|
||||
Write(types.Config) error
|
||||
}
|
||||
|
||||
// Ensure FileIO implements ConfigStore interface
|
||||
var _ ConfigStore = &FileIO{}
|
||||
|
||||
type FileIO struct {
|
||||
configFilePath string
|
||||
}
|
||||
|
||||
func New() *FileIO {
|
||||
path, _ := getPath()
|
||||
return &FileIO{
|
||||
configFilePath: path,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FileIO) WithFilePath(configFilePath string) *FileIO {
|
||||
f.configFilePath = configFilePath
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *FileIO) Read() (types.Config, error) {
|
||||
return parseFile(f.configFilePath)
|
||||
}
|
||||
|
||||
func (f *FileIO) Write(config types.Config) error {
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(f.configFilePath, data, 0644)
|
||||
}
|
||||
|
||||
func getPath() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(homeDir, ".chatgpt-cli", "config.yaml"), nil
|
||||
}
|
||||
|
||||
func parseFile(fileName string) (types.Config, error) {
|
||||
var result types.Config
|
||||
|
||||
buf, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return types.Config{}, err
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(buf, &result); err != nil {
|
||||
return types.Config{}, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
27
configmanager/configmanager.go
Normal file
27
configmanager/configmanager.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package configmanager
|
||||
|
||||
import (
|
||||
"github.com/kardolus/chatgpt-cli/config"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
type ConfigManager struct {
|
||||
configStore config.ConfigStore
|
||||
}
|
||||
|
||||
func New(cs config.ConfigStore) *ConfigManager {
|
||||
return &ConfigManager{configStore: cs}
|
||||
}
|
||||
|
||||
func (c *ConfigManager) ReadModel() (string, error) {
|
||||
conf, err := c.configStore.Read()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conf.Model, nil
|
||||
}
|
||||
|
||||
func (c *ConfigManager) WriteModel(model string) error {
|
||||
return c.configStore.Write(types.Config{Model: model})
|
||||
}
|
||||
62
configmanager/configmanager_test.go
Normal file
62
configmanager/configmanager_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package configmanager_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/kardolus/chatgpt-cli/configmanager"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sclevine/spec"
|
||||
"github.com/sclevine/spec/report"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=configmocks_test.go -package=configmanager_test github.com/kardolus/chatgpt-cli/config ConfigStore
|
||||
|
||||
func TestUnitConfigManager(t *testing.T) {
|
||||
spec.Run(t, "Testing the Config Manager", testConfig, spec.Report(report.Terminal{}))
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T, when spec.G, it spec.S) {
|
||||
var (
|
||||
mockCtrl *gomock.Controller
|
||||
mockConfigStore *MockConfigStore
|
||||
subject *configmanager.ConfigManager
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
RegisterTestingT(t)
|
||||
mockCtrl = gomock.NewController(t)
|
||||
mockConfigStore = NewMockConfigStore(mockCtrl)
|
||||
subject = configmanager.New(mockConfigStore)
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
mockCtrl.Finish()
|
||||
})
|
||||
|
||||
when("ReadModel()", func() {
|
||||
it("throws an error when the config file does not exist", func() {
|
||||
expectedErrorMsg := "file not found"
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New(expectedErrorMsg)).Times(1)
|
||||
_, err := subject.ReadModel()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(expectedErrorMsg))
|
||||
})
|
||||
it("parses a config file as expected", func() {
|
||||
modelName := "the-model"
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{Model: modelName}, nil).Times(1)
|
||||
|
||||
result, err := subject.ReadModel()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal(modelName))
|
||||
})
|
||||
})
|
||||
when("WriteModel()", func() {
|
||||
it("writes the expected config file", func() {
|
||||
modelName := "the-model"
|
||||
mockConfigStore.EXPECT().Write(types.Config{Model: modelName}).Times(1)
|
||||
Expect(subject.WriteModel(modelName)).To(Succeed())
|
||||
})
|
||||
})
|
||||
}
|
||||
64
configmanager/configmocks_test.go
Normal file
64
configmanager/configmocks_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore)
|
||||
|
||||
// Package configmanager_test is a generated GoMock package.
|
||||
package configmanager_test
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
types "github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
// MockConfigStore is a mock of ConfigStore interface.
|
||||
type MockConfigStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfigStoreMockRecorder
|
||||
}
|
||||
|
||||
// MockConfigStoreMockRecorder is the mock recorder for MockConfigStore.
|
||||
type MockConfigStoreMockRecorder struct {
|
||||
mock *MockConfigStore
|
||||
}
|
||||
|
||||
// NewMockConfigStore creates a new mock instance.
|
||||
func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore {
|
||||
mock := &MockConfigStore{ctrl: ctrl}
|
||||
mock.recorder = &MockConfigStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockConfigStore) Read() (types.Config, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read")
|
||||
ret0, _ := ret[0].(types.Config)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockConfigStore) Write(arg0 types.Config) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write.
|
||||
func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0)
|
||||
}
|
||||
@@ -3,19 +3,18 @@ package history
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
type HistoryStore interface {
|
||||
Delete() error
|
||||
Read() ([]types.Message, error)
|
||||
Write([]types.Message) error
|
||||
}
|
||||
|
||||
// Ensure RestCaller implements Caller interface
|
||||
var _ Store = &FileIO{}
|
||||
// Ensure FileIO implements the HistoryStore interface
|
||||
var _ HistoryStore = &FileIO{}
|
||||
|
||||
type FileIO struct {
|
||||
historyFilePath string
|
||||
@@ -28,7 +27,7 @@ func New() *FileIO {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FileIO) WithHistory(historyFilePath string) *FileIO {
|
||||
func (f *FileIO) WithFilePath(historyFilePath string) *FileIO {
|
||||
f.historyFilePath = historyFilePath
|
||||
return f
|
||||
}
|
||||
@@ -51,7 +50,7 @@ func (f *FileIO) Write(messages []types.Message) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(f.historyFilePath, data, 0644)
|
||||
return os.WriteFile(f.historyFilePath, data, 0644)
|
||||
}
|
||||
|
||||
func getPath() (string, error) {
|
||||
@@ -66,7 +65,7 @@ func getPath() (string, error) {
|
||||
func parseFile(fileName string) ([]types.Message, error) {
|
||||
var result []types.Message
|
||||
|
||||
buf, err := ioutil.ReadFile(fileName)
|
||||
buf, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"github.com/kardolus/chatgpt-cli/config"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
"github.com/sclevine/spec"
|
||||
"github.com/sclevine/spec/report"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -21,7 +21,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
|
||||
RegisterTestingT(t)
|
||||
})
|
||||
|
||||
when("Read, Write and Delete", func() {
|
||||
when("Read, Write and Delete History", func() {
|
||||
var (
|
||||
tmpDir string
|
||||
tmpFile *os.File
|
||||
@@ -31,15 +31,15 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
tmpDir, err = ioutil.TempDir("", "chatgpt-cli-test")
|
||||
tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
tmpFile, err = ioutil.TempFile(tmpDir, "history.json")
|
||||
tmpFile, err = os.CreateTemp(tmpDir, "history.json")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
tmpFile.Close()
|
||||
Expect(tmpFile.Close()).To(Succeed())
|
||||
|
||||
fileIO = history.New().WithHistory(tmpFile.Name())
|
||||
fileIO = history.New().WithFilePath(tmpFile.Name())
|
||||
|
||||
messages = []types.Message{
|
||||
{
|
||||
@@ -54,7 +54,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
Expect(os.RemoveAll(tmpDir)).To(Succeed())
|
||||
})
|
||||
|
||||
it("writes the messages to the file", func() {
|
||||
@@ -79,4 +79,62 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(os.IsNotExist(err)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
when("Read, Write Config", func() {
|
||||
var (
|
||||
tmpDir string
|
||||
tmpFile *os.File
|
||||
configIO *config.FileIO
|
||||
testConfig types.Config
|
||||
err error
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
tmpFile, err = os.CreateTemp(tmpDir, "config.yaml")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(tmpFile.Close()).To(Succeed())
|
||||
|
||||
configIO = config.New().WithFilePath(tmpFile.Name())
|
||||
|
||||
testConfig = types.Config{
|
||||
Model: "test-model",
|
||||
}
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
Expect(os.RemoveAll(tmpDir)).To(Succeed())
|
||||
})
|
||||
|
||||
it("writes the config to the file", func() {
|
||||
err = configIO.Write(testConfig)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
it("reads the config from the file", func() {
|
||||
err = configIO.Write(testConfig) // need to write before reading
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
readConfig, err := configIO.Read()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(readConfig).To(Equal(testConfig))
|
||||
})
|
||||
|
||||
// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration.
|
||||
it("overwrites the existing config", func() {
|
||||
newConfig := types.Config{
|
||||
Model: "new-model",
|
||||
}
|
||||
err = configIO.Write(newConfig)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
readConfig, err := configIO.Read()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(readConfig).To(Equal(newConfig))
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
5
types/config.go
Normal file
5
types/config.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package types
|
||||
|
||||
type Config struct {
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
Reference in New Issue
Block a user