mirror of
				https://github.com/kardolus/chatgpt-cli.git
				synced 2024-09-08 23:15:00 +03:00 
			
		
		
		
	Add the --set-model flag
This commit is contained in:
		
							
								
								
									
										12
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								README.md
									
									
									
									
									
								
							| @@ -33,7 +33,7 @@ environment, demonstrating its practicality and effectiveness. | ||||
|   limits. | ||||
| * **Custom context from local files**: Provide custom context through piping for GPT model reference during | ||||
|   conversation. | ||||
| * **Custom chat models**: Use a custom chat model by specifying the model name with the `-m` or `--model` flag. | ||||
| * **Custom chat models**: Use a custom chat model by specifying the model name with the `--set-model` flag. Ensure that the model exists in the OpenAI model list. | ||||
| * **Model listing**: Get a list of available models by using the `-l` or `--list-models` flag. | ||||
| * **Viper integration**: Robust configuration management. | ||||
|  | ||||
| @@ -131,7 +131,15 @@ Then, use the pipe feature to provide this context to ChatGPT: | ||||
| cat context.txt | chatgpt "What kind of toy would Kya enjoy?" | ||||
| ``` | ||||
|  | ||||
| 6. To list all available models, use the -l or --list-models flag: | ||||
| 6. To set a specific model, use the `--set-model` flag followed by the model name: | ||||
|  | ||||
| ```shell | ||||
| chatgpt --set-model gpt-3.5-turbo-0301 | ||||
| ``` | ||||
|  | ||||
| Remember to check that the model exists in the OpenAI model list before setting it. | ||||
|  | ||||
| 7. To list all available models, use the -l or --list-models flag: | ||||
|  | ||||
| ```shell | ||||
| chatgpt --list-models | ||||
|   | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/kardolus/chatgpt-cli/config" | ||||
| 	"github.com/kardolus/chatgpt-cli/configmanager" | ||||
| 	"github.com/kardolus/chatgpt-cli/history" | ||||
| 	"github.com/kardolus/chatgpt-cli/http" | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| @@ -26,20 +28,28 @@ const ( | ||||
| ) | ||||
|  | ||||
| type Client struct { | ||||
| 	History    []types.Message | ||||
| 	caller     http.Caller | ||||
| 	capacity   int | ||||
| 	model      string | ||||
| 	readWriter history.Store | ||||
| 	History      []types.Message | ||||
| 	Model        string | ||||
| 	caller       http.Caller | ||||
| 	capacity     int | ||||
| 	historyStore history.HistoryStore | ||||
| } | ||||
|  | ||||
| func New(caller http.Caller, rw history.Store) *Client { | ||||
| 	return &Client{ | ||||
| 		caller:     caller, | ||||
| 		readWriter: rw, | ||||
| 		capacity:   MaxTokenSize, | ||||
| 		model:      DefaultGPTModel, | ||||
| func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client { | ||||
| 	result := &Client{ | ||||
| 		caller:       caller, | ||||
| 		historyStore: hs, | ||||
| 		capacity:     MaxTokenSize, | ||||
| 	} | ||||
|  | ||||
| 	// do not error out when the config cannot be read | ||||
| 	result.Model, _ = configmanager.New(cs).ReadModel() | ||||
|  | ||||
| 	if result.Model == "" { | ||||
| 		result.Model = DefaultGPTModel | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| func (c *Client) WithCapacity(capacity int) *Client { | ||||
| @@ -48,7 +58,7 @@ func (c *Client) WithCapacity(capacity int) *Client { | ||||
| } | ||||
|  | ||||
| func (c *Client) WithModel(model string) *Client { | ||||
| 	c.model = model | ||||
| 	c.Model = model | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| @@ -73,7 +83,7 @@ func (c *Client) ListModels() ([]string, error) { | ||||
|  | ||||
| 	for _, model := range response.Data { | ||||
| 		if strings.HasPrefix(model.Id, gptPrefix) { | ||||
| 			if model.Id != DefaultGPTModel { | ||||
| 			if model.Id != c.Model { | ||||
| 				result = append(result, fmt.Sprintf("- %s", model.Id)) | ||||
| 				continue | ||||
| 			} | ||||
| @@ -156,7 +166,7 @@ func (c *Client) Stream(input string) error { | ||||
| func (c *Client) createBody(stream bool) ([]byte, error) { | ||||
| 	body := types.CompletionsRequest{ | ||||
| 		Messages: c.History, | ||||
| 		Model:    c.model, | ||||
| 		Model:    c.Model, | ||||
| 		Stream:   stream, | ||||
| 	} | ||||
|  | ||||
| @@ -168,7 +178,7 @@ func (c *Client) initHistory() { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	c.History, _ = c.readWriter.Read() | ||||
| 	c.History, _ = c.historyStore.Read() | ||||
| 	if len(c.History) == 0 { | ||||
| 		c.History = []types.Message{{ | ||||
| 			Role:    SystemRole, | ||||
| @@ -232,7 +242,7 @@ func (c *Client) updateHistory(response string) { | ||||
| 		Role:    AssistantRole, | ||||
| 		Content: response, | ||||
| 	}) | ||||
| 	_ = c.readWriter.Write(c.History) | ||||
| 	_ = c.historyStore.Write(c.History) | ||||
| } | ||||
|  | ||||
| func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int { | ||||
|   | ||||
| @@ -16,13 +16,15 @@ import ( | ||||
| ) | ||||
|  | ||||
| //go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller | ||||
| //go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store | ||||
| //go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history HistoryStore | ||||
| //go:generate mockgen -destination=configmocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/config ConfigStore | ||||
|  | ||||
| var ( | ||||
| 	mockCtrl   *gomock.Controller | ||||
| 	mockCaller *MockCaller | ||||
| 	mockStore  *MockStore | ||||
| 	subject    *client.Client | ||||
| 	mockCtrl         *gomock.Controller | ||||
| 	mockCaller       *MockCaller | ||||
| 	mockHistoryStore *MockHistoryStore | ||||
| 	mockConfigStore  *MockConfigStore | ||||
| 	factory          *clientFactory | ||||
| ) | ||||
|  | ||||
| func TestUnitClient(t *testing.T) { | ||||
| @@ -36,8 +38,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 		RegisterTestingT(t) | ||||
| 		mockCtrl = gomock.NewController(t) | ||||
| 		mockCaller = NewMockCaller(mockCtrl) | ||||
| 		mockStore = NewMockStore(mockCtrl) | ||||
| 		subject = client.New(mockCaller, mockStore).WithCapacity(50) | ||||
| 		mockHistoryStore = NewMockHistoryStore(mockCtrl) | ||||
| 		mockConfigStore = NewMockConfigStore(mockCtrl) | ||||
| 		factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore) | ||||
| 	}) | ||||
|  | ||||
| 	it.After(func() { | ||||
| @@ -51,63 +54,75 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			err      error | ||||
| 		) | ||||
|  | ||||
| 		it.Before(func() { | ||||
| 			messages = createMessages(nil, query) | ||||
| 			body, err = createBody(messages, false) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 		}) | ||||
| 		type TestCase struct { | ||||
| 			description     string | ||||
| 			setupPostReturn func() ([]byte, error) | ||||
| 			postError       error | ||||
| 			expectedError   string | ||||
| 		} | ||||
|  | ||||
| 		it("throws an error when the http callout fails", func() { | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 		tests := []TestCase{ | ||||
| 			{ | ||||
| 				description:     "throws an error when the http callout fails", | ||||
| 				setupPostReturn: func() ([]byte, error) { return nil, nil }, | ||||
| 				postError:       errors.New("error message"), | ||||
| 				expectedError:   "error message", | ||||
| 			}, | ||||
| 			{ | ||||
| 				description:     "throws an error when the response is empty", | ||||
| 				setupPostReturn: func() ([]byte, error) { return nil, nil }, | ||||
| 				postError:       nil, | ||||
| 				expectedError:   "empty response", | ||||
| 			}, | ||||
| 			{ | ||||
| 				description: "throws an error when the response is a malformed json", | ||||
| 				setupPostReturn: func() ([]byte, error) { | ||||
| 					malformed := `{"invalid":"json"` // missing closing brace | ||||
| 					return []byte(malformed), nil | ||||
| 				}, | ||||
| 				postError:     nil, | ||||
| 				expectedError: "failed to decode response:", | ||||
| 			}, | ||||
| 			{ | ||||
| 				description: "throws an error when the response is missing Choices", | ||||
| 				setupPostReturn: func() ([]byte, error) { | ||||
| 					response := &types.CompletionsResponse{ | ||||
| 						ID:      "id", | ||||
| 						Object:  "object", | ||||
| 						Created: 0, | ||||
| 						Model:   "model", | ||||
| 						Choices: []types.Choice{}, | ||||
| 					} | ||||
|  | ||||
| 			errorMsg := "error message" | ||||
| 			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg)) | ||||
| 					respBytes, err := json.Marshal(response) | ||||
| 					return respBytes, err | ||||
| 				}, | ||||
| 				postError:     nil, | ||||
| 				expectedError: "no responses returned", | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 			_, err := subject.Query(query) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).To(Equal(errorMsg)) | ||||
| 		}) | ||||
| 		it("throws an error when the response is empty", func() { | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil) | ||||
| 		for _, tt := range tests { | ||||
| 			it(tt.description, func() { | ||||
| 				factory.withoutHistory() | ||||
| 				subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			_, err := subject.Query(query) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).To(Equal("empty response")) | ||||
| 		}) | ||||
| 		it("throws an error when the response is a malformed json", func() { | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 				messages = createMessages(nil, query) | ||||
| 				body, err = createBody(messages, subject.Model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			malformed := `{"invalid":"json"` // missing closing brace | ||||
| 			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil) | ||||
| 				respBytes, err := tt.setupPostReturn() | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
| 				mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError) | ||||
|  | ||||
| 			_, err := subject.Query(query) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).Should(HavePrefix("failed to decode response:")) | ||||
| 		}) | ||||
| 		it("throws an error when the response is missing Choices", func() { | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 				_, err = subject.Query(query) | ||||
| 				Expect(err).To(HaveOccurred()) | ||||
| 				Expect(err.Error()).To(ContainSubstring(tt.expectedError)) | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 			response := &types.CompletionsResponse{ | ||||
| 				ID:      "id", | ||||
| 				Object:  "object", | ||||
| 				Created: 0, | ||||
| 				Model:   "model", | ||||
| 				Choices: []types.Choice{}, | ||||
| 			} | ||||
|  | ||||
| 			respBytes, err := json.Marshal(response) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil) | ||||
|  | ||||
| 			_, err = subject.Query(query) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).To(Equal("no responses returned")) | ||||
| 		}) | ||||
| 		when("a valid http response is received", func() { | ||||
| 			testValidHTTPResponse := func(history []types.Message, expectedBody []byte) { | ||||
| 				mockStore.EXPECT().Read().Return(history, nil).Times(1) | ||||
|  | ||||
| 			testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) { | ||||
| 				const answer = "content" | ||||
|  | ||||
| 				choice := types.Choice{ | ||||
| @@ -122,7 +137,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 					ID:      "id", | ||||
| 					Object:  "object", | ||||
| 					Created: 0, | ||||
| 					Model:   client.DefaultGPTModel, | ||||
| 					Model:   subject.Model, | ||||
| 					Choices: []types.Choice{choice}, | ||||
| 				} | ||||
|  | ||||
| @@ -134,7 +149,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 				err = json.Unmarshal(expectedBody, &request) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				mockStore.EXPECT().Write(append(request.Messages, types.Message{ | ||||
| 				mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{ | ||||
| 					Role:    client.AssistantRole, | ||||
| 					Content: answer, | ||||
| 				})) | ||||
| @@ -144,8 +159,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 				Expect(result).To(Equal(answer)) | ||||
| 			} | ||||
|  | ||||
| 			it("returns the expected result for an empty history", func() { | ||||
| 				testValidHTTPResponse(nil, body) | ||||
| 			it("uses the model specified by the WithModel method instead of the default model", func() { | ||||
| 				const model = "overwritten" | ||||
|  | ||||
| 				messages = createMessages(nil, query) | ||||
| 				factory.withoutHistory() | ||||
| 				subject := factory.buildClientWithoutConfig().WithModel(model) | ||||
|  | ||||
| 				body, err = createBody(messages, model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
| 				testValidHTTPResponse(subject, nil, body) | ||||
| 			}) | ||||
| 			it("uses the model specified by the configuration instead of the default model", func() { | ||||
| 				const model = "overwritten" | ||||
|  | ||||
| 				messages = createMessages(nil, query) | ||||
| 				factory.withoutHistory() | ||||
| 				subject := factory.buildClientWithConfig(types.Config{ | ||||
| 					Model: model, | ||||
| 				}) | ||||
|  | ||||
| 				body, err = createBody(messages, model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
| 				testValidHTTPResponse(subject, nil, body) | ||||
| 			}) | ||||
| 			it("when WithModel is used and a configuration is present, WithModel takes precedence", func() { | ||||
| 				const model = "with-model" | ||||
|  | ||||
| 				messages = createMessages(nil, query) | ||||
| 				factory.withoutHistory() | ||||
| 				subject := factory.buildClientWithConfig(types.Config{ | ||||
| 					Model: "config-model", | ||||
| 				}).WithModel(model) | ||||
|  | ||||
| 				body, err = createBody(messages, model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
| 				testValidHTTPResponse(subject, nil, body) | ||||
| 			}) | ||||
| 			it("returns the expected result for a non-empty history", func() { | ||||
| 				history := []types.Message{ | ||||
| @@ -163,10 +212,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 					}, | ||||
| 				} | ||||
| 				messages = createMessages(history, query) | ||||
| 				body, err = createBody(messages, false) | ||||
| 				factory.withHistory(history) | ||||
| 				subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 				body, err = createBody(messages, subject.Model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				testValidHTTPResponse(history, body) | ||||
| 				testValidHTTPResponse(subject, history, body) | ||||
| 			}) | ||||
| 			it("truncates the history as expected", func() { | ||||
| 				history := []types.Message{ | ||||
| @@ -202,13 +254,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
|  | ||||
| 				messages = createMessages(history, query) | ||||
|  | ||||
| 				factory.withHistory(history) | ||||
| 				subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 				// messages get truncated. Index 1+2 are cut out | ||||
| 				messages = append(messages[:1], messages[3:]...) | ||||
|  | ||||
| 				body, err = createBody(messages, false) | ||||
| 				body, err = createBody(messages, subject.Model, false) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				testValidHTTPResponse(history, body) | ||||
| 				testValidHTTPResponse(subject, history, body) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| @@ -219,14 +274,13 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			err      error | ||||
| 		) | ||||
|  | ||||
| 		it.Before(func() { | ||||
| 			messages = createMessages(nil, query) | ||||
| 			body, err = createBody(messages, true) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 		}) | ||||
|  | ||||
| 		it("throws an error when the http callout fails", func() { | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 			factory.withoutHistory() | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			messages = createMessages(nil, query) | ||||
| 			body, err = createBody(messages, subject.Model, true) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			errorMsg := "error message" | ||||
| 			mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg)) | ||||
| @@ -238,13 +292,16 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 		when("a valid http response is received", func() { | ||||
| 			const answer = "answer" | ||||
|  | ||||
| 			testValidHTTPResponse := func(history []types.Message, expectedBody []byte) { | ||||
| 				mockStore.EXPECT().Read().Return(history, nil).Times(1) | ||||
| 			testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) { | ||||
| 				messages = createMessages(nil, query) | ||||
| 				body, err = createBody(messages, subject.Model, true) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil) | ||||
|  | ||||
| 				messages = createMessages(history, query) | ||||
|  | ||||
| 				mockStore.EXPECT().Write(append(messages, types.Message{ | ||||
| 				mockHistoryStore.EXPECT().Write(append(messages, types.Message{ | ||||
| 					Role:    client.AssistantRole, | ||||
| 					Content: answer, | ||||
| 				})) | ||||
| @@ -254,7 +311,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			} | ||||
|  | ||||
| 			it("returns the expected result for an empty history", func() { | ||||
| 				testValidHTTPResponse(nil, body) | ||||
| 				factory.withHistory(nil) | ||||
| 				subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 				messages = createMessages(nil, query) | ||||
| 				body, err = createBody(messages, subject.Model, true) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				testValidHTTPResponse(subject, nil, body) | ||||
| 			}) | ||||
| 			it("returns the expected result for a non-empty history", func() { | ||||
| 				history := []types.Message{ | ||||
| @@ -271,16 +335,21 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 						Content: "answer 1", | ||||
| 					}, | ||||
| 				} | ||||
| 				factory.withHistory(history) | ||||
| 				subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 				messages = createMessages(history, query) | ||||
| 				body, err = createBody(messages, true) | ||||
| 				body, err = createBody(messages, subject.Model, true) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				testValidHTTPResponse(history, body) | ||||
| 				testValidHTTPResponse(subject, history, body) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 	when("ListModels()", func() { | ||||
| 		it("throws an error when the http callout fails", func() { | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			errorMsg := "error message" | ||||
| 			mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg)) | ||||
|  | ||||
| @@ -289,6 +358,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			Expect(err.Error()).To(Equal(errorMsg)) | ||||
| 		}) | ||||
| 		it("throws an error when the response is empty", func() { | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil) | ||||
|  | ||||
| 			_, err := subject.ListModels() | ||||
| @@ -296,6 +367,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			Expect(err.Error()).To(Equal("empty response")) | ||||
| 		}) | ||||
| 		it("throws an error when the response is a malformed json", func() { | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			malformed := `{"invalid":"json"` // missing closing brace | ||||
| 			mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil) | ||||
|  | ||||
| @@ -304,6 +377,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 			Expect(err.Error()).Should(HavePrefix("failed to decode response:")) | ||||
| 		}) | ||||
| 		it("filters gpt models as expected", func() { | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			response, err := utils.FileToBytes("models.json") | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| @@ -319,8 +394,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 	}) | ||||
| 	when("ProvideContext()", func() { | ||||
| 		it("updates the history with the provided context", func() { | ||||
| 			subject := factory.buildClientWithoutConfig() | ||||
|  | ||||
| 			context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake." | ||||
| 			mockStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| 			mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
|  | ||||
| 			subject.ProvideContext(context) | ||||
|  | ||||
| 			Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context | ||||
| @@ -336,9 +414,9 @@ func testClient(t *testing.T, when spec.G, it spec.S) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func createBody(messages []types.Message, stream bool) ([]byte, error) { | ||||
| func createBody(messages []types.Message, model string, stream bool) ([]byte, error) { | ||||
| 	req := types.CompletionsRequest{ | ||||
| 		Model:    client.DefaultGPTModel, | ||||
| 		Model:    model, | ||||
| 		Messages: messages, | ||||
| 		Stream:   stream, | ||||
| 	} | ||||
| @@ -365,3 +443,35 @@ func createMessages(history []types.Message, query string) []types.Message { | ||||
|  | ||||
| 	return messages | ||||
| } | ||||
|  | ||||
| type clientFactory struct { | ||||
| 	mockCaller       *MockCaller | ||||
| 	mockConfigStore  *MockConfigStore | ||||
| 	mockHistoryStore *MockHistoryStore | ||||
| } | ||||
|  | ||||
| func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory { | ||||
| 	return &clientFactory{ | ||||
| 		mockCaller:       mc, | ||||
| 		mockConfigStore:  mcs, | ||||
| 		mockHistoryStore: mhs, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (f *clientFactory) buildClientWithoutConfig() *client.Client { | ||||
| 	f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1) | ||||
| 	return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50) | ||||
| } | ||||
|  | ||||
| func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client { | ||||
| 	f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1) | ||||
| 	return client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore).WithCapacity(50) | ||||
| } | ||||
|  | ||||
| func (f *clientFactory) withoutHistory() { | ||||
| 	f.mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1) | ||||
| } | ||||
|  | ||||
| func (f *clientFactory) withHistory(history []types.Message) { | ||||
| 	f.mockHistoryStore.EXPECT().Read().Return(history, nil).Times(1) | ||||
| } | ||||
|   | ||||
							
								
								
									
										64
									
								
								client/configmocks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								client/configmocks_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| // Code generated by MockGen. DO NOT EDIT. | ||||
| // Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore) | ||||
|  | ||||
| // Package client_test is a generated GoMock package. | ||||
| package client_test | ||||
|  | ||||
| import ( | ||||
| 	reflect "reflect" | ||||
|  | ||||
| 	gomock "github.com/golang/mock/gomock" | ||||
| 	types "github.com/kardolus/chatgpt-cli/types" | ||||
| ) | ||||
|  | ||||
| // MockConfigStore is a mock of ConfigStore interface. | ||||
| type MockConfigStore struct { | ||||
| 	ctrl     *gomock.Controller | ||||
| 	recorder *MockConfigStoreMockRecorder | ||||
| } | ||||
|  | ||||
| // MockConfigStoreMockRecorder is the mock recorder for MockConfigStore. | ||||
| type MockConfigStoreMockRecorder struct { | ||||
| 	mock *MockConfigStore | ||||
| } | ||||
|  | ||||
| // NewMockConfigStore creates a new mock instance. | ||||
| func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore { | ||||
| 	mock := &MockConfigStore{ctrl: ctrl} | ||||
| 	mock.recorder = &MockConfigStoreMockRecorder{mock} | ||||
| 	return mock | ||||
| } | ||||
|  | ||||
| // EXPECT returns an object that allows the caller to indicate expected use. | ||||
| func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder { | ||||
| 	return m.recorder | ||||
| } | ||||
|  | ||||
| // Read mocks base method. | ||||
| func (m *MockConfigStore) Read() (types.Config, error) { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Read") | ||||
| 	ret0, _ := ret[0].(types.Config) | ||||
| 	ret1, _ := ret[1].(error) | ||||
| 	return ret0, ret1 | ||||
| } | ||||
|  | ||||
| // Read indicates an expected call of Read. | ||||
| func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read)) | ||||
| } | ||||
|  | ||||
| // Write mocks base method. | ||||
| func (m *MockConfigStore) Write(arg0 types.Config) error { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Write", arg0) | ||||
| 	ret0, _ := ret[0].(error) | ||||
| 	return ret0 | ||||
| } | ||||
|  | ||||
| // Write indicates an expected call of Write. | ||||
| func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0) | ||||
| } | ||||
| @@ -1,5 +1,5 @@ | ||||
| // Code generated by MockGen. DO NOT EDIT. | ||||
| // Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store) | ||||
| // Source: github.com/kardolus/chatgpt-cli/history (interfaces: HistoryStore) | ||||
| 
 | ||||
| // Package client_test is a generated GoMock package. | ||||
| package client_test | ||||
| @@ -11,31 +11,31 @@ import ( | ||||
| 	types "github.com/kardolus/chatgpt-cli/types" | ||||
| ) | ||||
| 
 | ||||
| // MockStore is a mock of Store interface. | ||||
| type MockStore struct { | ||||
| // MockHistoryStore is a mock of HistoryStore interface. | ||||
| type MockHistoryStore struct { | ||||
| 	ctrl     *gomock.Controller | ||||
| 	recorder *MockStoreMockRecorder | ||||
| 	recorder *MockHistoryStoreMockRecorder | ||||
| } | ||||
| 
 | ||||
| // MockStoreMockRecorder is the mock recorder for MockStore. | ||||
| type MockStoreMockRecorder struct { | ||||
| 	mock *MockStore | ||||
| // MockHistoryStoreMockRecorder is the mock recorder for MockHistoryStore. | ||||
| type MockHistoryStoreMockRecorder struct { | ||||
| 	mock *MockHistoryStore | ||||
| } | ||||
| 
 | ||||
| // NewMockStore creates a new mock instance. | ||||
| func NewMockStore(ctrl *gomock.Controller) *MockStore { | ||||
| 	mock := &MockStore{ctrl: ctrl} | ||||
| 	mock.recorder = &MockStoreMockRecorder{mock} | ||||
| // NewMockHistoryStore creates a new mock instance. | ||||
| func NewMockHistoryStore(ctrl *gomock.Controller) *MockHistoryStore { | ||||
| 	mock := &MockHistoryStore{ctrl: ctrl} | ||||
| 	mock.recorder = &MockHistoryStoreMockRecorder{mock} | ||||
| 	return mock | ||||
| } | ||||
| 
 | ||||
| // EXPECT returns an object that allows the caller to indicate expected use. | ||||
| func (m *MockStore) EXPECT() *MockStoreMockRecorder { | ||||
| func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder { | ||||
| 	return m.recorder | ||||
| } | ||||
| 
 | ||||
| // Delete mocks base method. | ||||
| func (m *MockStore) Delete() error { | ||||
| func (m *MockHistoryStore) Delete() error { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Delete") | ||||
| 	ret0, _ := ret[0].(error) | ||||
| @@ -43,13 +43,13 @@ func (m *MockStore) Delete() error { | ||||
| } | ||||
| 
 | ||||
| // Delete indicates an expected call of Delete. | ||||
| func (mr *MockStoreMockRecorder) Delete() *gomock.Call { | ||||
| func (mr *MockHistoryStoreMockRecorder) Delete() *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete)) | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockHistoryStore)(nil).Delete)) | ||||
| } | ||||
| 
 | ||||
| // Read mocks base method. | ||||
| func (m *MockStore) Read() ([]types.Message, error) { | ||||
| func (m *MockHistoryStore) Read() ([]types.Message, error) { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Read") | ||||
| 	ret0, _ := ret[0].([]types.Message) | ||||
| @@ -58,13 +58,13 @@ func (m *MockStore) Read() ([]types.Message, error) { | ||||
| } | ||||
| 
 | ||||
| // Read indicates an expected call of Read. | ||||
| func (mr *MockStoreMockRecorder) Read() *gomock.Call { | ||||
| func (mr *MockHistoryStoreMockRecorder) Read() *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read)) | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockHistoryStore)(nil).Read)) | ||||
| } | ||||
| 
 | ||||
| // Write mocks base method. | ||||
| func (m *MockStore) Write(arg0 []types.Message) error { | ||||
| func (m *MockHistoryStore) Write(arg0 []types.Message) error { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Write", arg0) | ||||
| 	ret0, _ := ret[0].(error) | ||||
| @@ -72,7 +72,7 @@ func (m *MockStore) Write(arg0 []types.Message) error { | ||||
| } | ||||
| 
 | ||||
| // Write indicates an expected call of Write. | ||||
| func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call { | ||||
| func (mr *MockHistoryStoreMockRecorder) Write(arg0 interface{}) *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0) | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockHistoryStore)(nil).Write), arg0) | ||||
| } | ||||
| @@ -5,6 +5,8 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/kardolus/chatgpt-cli/client" | ||||
| 	"github.com/kardolus/chatgpt-cli/config" | ||||
| 	"github.com/kardolus/chatgpt-cli/configmanager" | ||||
| 	"github.com/kardolus/chatgpt-cli/history" | ||||
| 	"github.com/kardolus/chatgpt-cli/http" | ||||
| 	"github.com/spf13/cobra" | ||||
| @@ -43,6 +45,7 @@ func main() { | ||||
| 	rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information") | ||||
| 	rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models") | ||||
| 	rootCmd.PersistentFlags().StringVarP(&modelName, "model", "m", "", "Use a custom GPT model by specifying the model name") | ||||
| 	rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name") | ||||
|  | ||||
| 	viper.AutomaticEnv() | ||||
|  | ||||
| @@ -57,7 +60,7 @@ func run(cmd *cobra.Command, args []string) error { | ||||
| 	if secret == "" { | ||||
| 		return errors.New("missing environment variable: " + secretEnv) | ||||
| 	} | ||||
| 	client := client.New(http.New().WithSecret(secret), history.New()) | ||||
| 	client := client.New(http.New().WithSecret(secret), config.New(), history.New()) | ||||
|  | ||||
| 	if modelName != "" { | ||||
| 		client = client.WithModel(modelName) | ||||
| @@ -78,6 +81,14 @@ func run(cmd *cobra.Command, args []string) error { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if cmd.Flag("set-model").Changed { | ||||
| 		if err := configmanager.New(config.New()).WriteModel(modelName); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		fmt.Println("Model successfully updated to", modelName) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if clearHistory { | ||||
| 		historyHandler := history.New() | ||||
| 		err := historyHandler.Delete() | ||||
|   | ||||
							
								
								
									
										69
									
								
								config/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								config/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| 	"gopkg.in/yaml.v3" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| ) | ||||
|  | ||||
| type ConfigStore interface { | ||||
| 	Read() (types.Config, error) | ||||
| 	Write(types.Config) error | ||||
| } | ||||
|  | ||||
| // Ensure FileIO implements ConfigStore interface | ||||
| var _ ConfigStore = &FileIO{} | ||||
|  | ||||
| type FileIO struct { | ||||
| 	configFilePath string | ||||
| } | ||||
|  | ||||
| func New() *FileIO { | ||||
| 	path, _ := getPath() | ||||
| 	return &FileIO{ | ||||
| 		configFilePath: path, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (f *FileIO) WithFilePath(configFilePath string) *FileIO { | ||||
| 	f.configFilePath = configFilePath | ||||
| 	return f | ||||
| } | ||||
|  | ||||
| func (f *FileIO) Read() (types.Config, error) { | ||||
| 	return parseFile(f.configFilePath) | ||||
| } | ||||
|  | ||||
| func (f *FileIO) Write(config types.Config) error { | ||||
| 	data, err := yaml.Marshal(config) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return os.WriteFile(f.configFilePath, data, 0644) | ||||
| } | ||||
|  | ||||
| func getPath() (string, error) { | ||||
| 	homeDir, err := os.UserHomeDir() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return filepath.Join(homeDir, ".chatgpt-cli", "config.yaml"), nil | ||||
| } | ||||
|  | ||||
| func parseFile(fileName string) (types.Config, error) { | ||||
| 	var result types.Config | ||||
|  | ||||
| 	buf, err := os.ReadFile(fileName) | ||||
| 	if err != nil { | ||||
| 		return types.Config{}, err | ||||
| 	} | ||||
|  | ||||
| 	if err := yaml.Unmarshal(buf, &result); err != nil { | ||||
| 		return types.Config{}, err | ||||
| 	} | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
							
								
								
									
										27
									
								
								configmanager/configmanager.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								configmanager/configmanager.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| package configmanager | ||||
|  | ||||
| import ( | ||||
| 	"github.com/kardolus/chatgpt-cli/config" | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| ) | ||||
|  | ||||
| type ConfigManager struct { | ||||
| 	configStore config.ConfigStore | ||||
| } | ||||
|  | ||||
| func New(cs config.ConfigStore) *ConfigManager { | ||||
| 	return &ConfigManager{configStore: cs} | ||||
| } | ||||
|  | ||||
| func (c *ConfigManager) ReadModel() (string, error) { | ||||
| 	conf, err := c.configStore.Read() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return conf.Model, nil | ||||
| } | ||||
|  | ||||
| func (c *ConfigManager) WriteModel(model string) error { | ||||
| 	return c.configStore.Write(types.Config{Model: model}) | ||||
| } | ||||
							
								
								
									
										62
									
								
								configmanager/configmanager_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								configmanager/configmanager_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| package configmanager_test | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/golang/mock/gomock" | ||||
| 	"github.com/kardolus/chatgpt-cli/configmanager" | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/sclevine/spec" | ||||
| 	"github.com/sclevine/spec/report" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| //go:generate mockgen -destination=configmocks_test.go -package=configmanager_test github.com/kardolus/chatgpt-cli/config ConfigStore | ||||
|  | ||||
| func TestUnitConfigManager(t *testing.T) { | ||||
| 	spec.Run(t, "Testing the Config Manager", testConfig, spec.Report(report.Terminal{})) | ||||
| } | ||||
|  | ||||
| func testConfig(t *testing.T, when spec.G, it spec.S) { | ||||
| 	var ( | ||||
| 		mockCtrl        *gomock.Controller | ||||
| 		mockConfigStore *MockConfigStore | ||||
| 		subject         *configmanager.ConfigManager | ||||
| 	) | ||||
|  | ||||
| 	it.Before(func() { | ||||
| 		RegisterTestingT(t) | ||||
| 		mockCtrl = gomock.NewController(t) | ||||
| 		mockConfigStore = NewMockConfigStore(mockCtrl) | ||||
| 		subject = configmanager.New(mockConfigStore) | ||||
| 	}) | ||||
|  | ||||
| 	it.After(func() { | ||||
| 		mockCtrl.Finish() | ||||
| 	}) | ||||
|  | ||||
| 	when("ReadModel()", func() { | ||||
| 		it("throws an error when the config file does not exist", func() { | ||||
| 			expectedErrorMsg := "file not found" | ||||
| 			mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New(expectedErrorMsg)).Times(1) | ||||
| 			_, err := subject.ReadModel() | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).To(Equal(expectedErrorMsg)) | ||||
| 		}) | ||||
| 		it("parses a config file as expected", func() { | ||||
| 			modelName := "the-model" | ||||
| 			mockConfigStore.EXPECT().Read().Return(types.Config{Model: modelName}, nil).Times(1) | ||||
|  | ||||
| 			result, err := subject.ReadModel() | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(result).To(Equal(modelName)) | ||||
| 		}) | ||||
| 	}) | ||||
| 	when("WriteModel()", func() { | ||||
| 		it("writes the expected config file", func() { | ||||
| 			modelName := "the-model" | ||||
| 			mockConfigStore.EXPECT().Write(types.Config{Model: modelName}).Times(1) | ||||
| 			Expect(subject.WriteModel(modelName)).To(Succeed()) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										64
									
								
								configmanager/configmocks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								configmanager/configmocks_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| // Code generated by MockGen. DO NOT EDIT. | ||||
| // Source: github.com/kardolus/chatgpt-cli/config (interfaces: ConfigStore) | ||||
|  | ||||
| // Package configmanager_test is a generated GoMock package. | ||||
| package configmanager_test | ||||
|  | ||||
| import ( | ||||
| 	reflect "reflect" | ||||
|  | ||||
| 	gomock "github.com/golang/mock/gomock" | ||||
| 	types "github.com/kardolus/chatgpt-cli/types" | ||||
| ) | ||||
|  | ||||
| // MockConfigStore is a mock of ConfigStore interface. | ||||
| type MockConfigStore struct { | ||||
| 	ctrl     *gomock.Controller | ||||
| 	recorder *MockConfigStoreMockRecorder | ||||
| } | ||||
|  | ||||
| // MockConfigStoreMockRecorder is the mock recorder for MockConfigStore. | ||||
| type MockConfigStoreMockRecorder struct { | ||||
| 	mock *MockConfigStore | ||||
| } | ||||
|  | ||||
| // NewMockConfigStore creates a new mock instance. | ||||
| func NewMockConfigStore(ctrl *gomock.Controller) *MockConfigStore { | ||||
| 	mock := &MockConfigStore{ctrl: ctrl} | ||||
| 	mock.recorder = &MockConfigStoreMockRecorder{mock} | ||||
| 	return mock | ||||
| } | ||||
|  | ||||
| // EXPECT returns an object that allows the caller to indicate expected use. | ||||
| func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder { | ||||
| 	return m.recorder | ||||
| } | ||||
|  | ||||
| // Read mocks base method. | ||||
| func (m *MockConfigStore) Read() (types.Config, error) { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Read") | ||||
| 	ret0, _ := ret[0].(types.Config) | ||||
| 	ret1, _ := ret[1].(error) | ||||
| 	return ret0, ret1 | ||||
| } | ||||
|  | ||||
| // Read indicates an expected call of Read. | ||||
| func (mr *MockConfigStoreMockRecorder) Read() *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConfigStore)(nil).Read)) | ||||
| } | ||||
|  | ||||
| // Write mocks base method. | ||||
| func (m *MockConfigStore) Write(arg0 types.Config) error { | ||||
| 	m.ctrl.T.Helper() | ||||
| 	ret := m.ctrl.Call(m, "Write", arg0) | ||||
| 	ret0, _ := ret[0].(error) | ||||
| 	return ret0 | ||||
| } | ||||
|  | ||||
| // Write indicates an expected call of Write. | ||||
| func (mr *MockConfigStoreMockRecorder) Write(arg0 interface{}) *gomock.Call { | ||||
| 	mr.mock.ctrl.T.Helper() | ||||
| 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConfigStore)(nil).Write), arg0) | ||||
| } | ||||
| @@ -3,19 +3,18 @@ package history | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| ) | ||||
|  | ||||
| type Store interface { | ||||
| type HistoryStore interface { | ||||
| 	Delete() error | ||||
| 	Read() ([]types.Message, error) | ||||
| 	Write([]types.Message) error | ||||
| } | ||||
|  | ||||
| // Ensure RestCaller implements Caller interface | ||||
| var _ Store = &FileIO{} | ||||
| // Ensure FileIO implements the HistoryStore interface | ||||
| var _ HistoryStore = &FileIO{} | ||||
|  | ||||
| type FileIO struct { | ||||
| 	historyFilePath string | ||||
| @@ -28,7 +27,7 @@ func New() *FileIO { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (f *FileIO) WithHistory(historyFilePath string) *FileIO { | ||||
| func (f *FileIO) WithFilePath(historyFilePath string) *FileIO { | ||||
| 	f.historyFilePath = historyFilePath | ||||
| 	return f | ||||
| } | ||||
| @@ -51,7 +50,7 @@ func (f *FileIO) Write(messages []types.Message) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return ioutil.WriteFile(f.historyFilePath, data, 0644) | ||||
| 	return os.WriteFile(f.historyFilePath, data, 0644) | ||||
| } | ||||
|  | ||||
| func getPath() (string, error) { | ||||
| @@ -66,7 +65,7 @@ func getPath() (string, error) { | ||||
| func parseFile(fileName string) ([]types.Message, error) { | ||||
| 	var result []types.Message | ||||
|  | ||||
| 	buf, err := ioutil.ReadFile(fileName) | ||||
| 	buf, err := os.ReadFile(fileName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -1,11 +1,11 @@ | ||||
| package integration_test | ||||
|  | ||||
| import ( | ||||
| 	"github.com/kardolus/chatgpt-cli/config" | ||||
| 	"github.com/kardolus/chatgpt-cli/history" | ||||
| 	"github.com/kardolus/chatgpt-cli/types" | ||||
| 	"github.com/sclevine/spec" | ||||
| 	"github.com/sclevine/spec/report" | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"testing" | ||||
|  | ||||
| @@ -21,7 +21,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
| 		RegisterTestingT(t) | ||||
| 	}) | ||||
|  | ||||
| 	when("Read, Write and Delete", func() { | ||||
| 	when("Read, Write and Delete History", func() { | ||||
| 		var ( | ||||
| 			tmpDir   string | ||||
| 			tmpFile  *os.File | ||||
| @@ -31,15 +31,15 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
| 		) | ||||
|  | ||||
| 		it.Before(func() { | ||||
| 			tmpDir, err = ioutil.TempDir("", "chatgpt-cli-test") | ||||
| 			tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test") | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			tmpFile, err = ioutil.TempFile(tmpDir, "history.json") | ||||
| 			tmpFile, err = os.CreateTemp(tmpDir, "history.json") | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			tmpFile.Close() | ||||
| 			Expect(tmpFile.Close()).To(Succeed()) | ||||
|  | ||||
| 			fileIO = history.New().WithHistory(tmpFile.Name()) | ||||
| 			fileIO = history.New().WithFilePath(tmpFile.Name()) | ||||
|  | ||||
| 			messages = []types.Message{ | ||||
| 				{ | ||||
| @@ -54,7 +54,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
| 		}) | ||||
|  | ||||
| 		it.After(func() { | ||||
| 			os.RemoveAll(tmpDir) | ||||
| 			Expect(os.RemoveAll(tmpDir)).To(Succeed()) | ||||
| 		}) | ||||
|  | ||||
| 		it("writes the messages to the file", func() { | ||||
| @@ -79,4 +79,62 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
| 			Expect(os.IsNotExist(err)).To(BeTrue()) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	when("Read, Write Config", func() { | ||||
| 		var ( | ||||
| 			tmpDir     string | ||||
| 			tmpFile    *os.File | ||||
| 			configIO   *config.FileIO | ||||
| 			testConfig types.Config | ||||
| 			err        error | ||||
| 		) | ||||
|  | ||||
| 		it.Before(func() { | ||||
| 			tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test") | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			tmpFile, err = os.CreateTemp(tmpDir, "config.yaml") | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			Expect(tmpFile.Close()).To(Succeed()) | ||||
|  | ||||
| 			configIO = config.New().WithFilePath(tmpFile.Name()) | ||||
|  | ||||
| 			testConfig = types.Config{ | ||||
| 				Model: "test-model", | ||||
| 			} | ||||
| 		}) | ||||
|  | ||||
| 		it.After(func() { | ||||
| 			Expect(os.RemoveAll(tmpDir)).To(Succeed()) | ||||
| 		}) | ||||
|  | ||||
| 		it("writes the config to the file", func() { | ||||
| 			err = configIO.Write(testConfig) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 		}) | ||||
|  | ||||
| 		it("reads the config from the file", func() { | ||||
| 			err = configIO.Write(testConfig) // need to write before reading | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			readConfig, err := configIO.Read() | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(readConfig).To(Equal(testConfig)) | ||||
| 		}) | ||||
|  | ||||
| 		// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration. | ||||
| 		it("overwrites the existing config", func() { | ||||
| 			newConfig := types.Config{ | ||||
| 				Model: "new-model", | ||||
| 			} | ||||
| 			err = configIO.Write(newConfig) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 			readConfig, err := configIO.Read() | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(readConfig).To(Equal(newConfig)) | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| } | ||||
|   | ||||
							
								
								
									
										5
									
								
								types/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								types/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package types | ||||
|  | ||||
| type Config struct { | ||||
| 	Model string `yaml:"model"` | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 kardolus
					kardolus