Upgrade to a more robust configuration

- ~./chatgpt-cli/config.yaml takes precedence over default values
- environment variables take precedence over config.yaml
- the name as provided in the config is the prefix for the environment
  variables (ie. openai)
- remove the --models flag
- add an omit_history configuration option
This commit is contained in:
kardolus
2023-06-19 09:07:40 -04:00
parent 7dc341a72c
commit 6c9de4ab55
13 changed files with 673 additions and 137 deletions

View File

@@ -133,15 +133,7 @@ Then, use the pipe feature to provide this context to ChatGPT:
cat context.txt | chatgpt "What kind of toy would Kya enjoy?" cat context.txt | chatgpt "What kind of toy would Kya enjoy?"
``` ```
6. To set a specific model, use the `--set-model` flag followed by the model name: 6. To list all available models, use the -l or --list-models flag:
```shell
chatgpt --set-model gpt-3.5-turbo-0301
```
Remember to check that the model exists in the OpenAI model list before setting it.
7. To list all available models, use the -l or --list-models flag:
```shell ```shell
chatgpt --list-models chatgpt --list-models
@@ -149,54 +141,67 @@ chatgpt --list-models
## Configuration ## Configuration
The chatGPT CLI uses a two-level configuration system. The default values are: The ChatGPT CLI adopts a three-tier configuration strategy, with different levels of precedence assigned to default
values, the `config.yaml` file, and environment variables, in that respective order.
The default configuration:
```yaml ```yaml
name: openai
api_key:
model: gpt-3.5-turbo model: gpt-3.5-turbo
max_tokens: 4096 max_tokens: 4096
url: https://api.openai.com url: https://api.openai.com
completions_path: /v1/chat/completions completions_path: /v1/chat/completions
models_path: /v1/models models_path: /v1/models
omit_history: false
``` ```
These default settings can be overwritten by user-defined configuration options. The user configuration file These defaults can be overridden by providing your own values in the user configuration file,
is `.chatgpt-cli/config.yaml` and is expected to be in the user's home directory. named `.chatgpt-cli/config.yaml`, located in your home directory.
The user configuration file follows the same structure as the default configuration file. Here is an example of how to The structure of the user configuration file mirrors that of the default configuration. For instance, to override
override the `model` and `max_tokens` values: the `model` and `max_tokens` parameters, your file might look like this:
```yaml ```yaml
model: gpt-3.5-turbo-16k model: gpt-3.5-turbo-16k
max_tokens: 8192 max_tokens: 8192
``` ```
In this example, the `model` is changed to `gpt-3.5-turbo-16k`, and `max_tokens` is set to `8192`. Other options such This alters the `model` to `gpt-3.5-turbo-16k` and adjusts `max_tokens` to `8192`. All other options, such as `url`
as `url`, `completions_path`, and `models_path` can be adjusted in the same manner if needed. , `completions_path`, and `models_path`, can similarly be modified. If the user configuration file cannot be accessed or
is missing, the application will resort to the default configuration.
Note: If the user configuration file is not found or cannot be read for any reason, the application will fall back to Another way to adjust values without manually editing the configuration file is by using environment variables.
the default configuration. The `name` attribute forms the prefix for these variables. As an example, the `model` can be modified using
the `OPENAI_MODEL` environment variable. Similarly, to disable history during the execution of a command, use:
As a more immediate and flexible alternative to changing the configuration file manually, the CLI offers command-line
flags for overwriting specific configuration values. For instance, the `model` can be changed using the `--model`
flag. This is particularly useful for temporary adjustments or testing different configurations.
```shell ```shell
chatgpt --model gpt-3.5-turbo-16k What are some fun things to do in Red Hook? OPENAI_OMIT_HISTORY=true chatgpt tell me a joke
``` ```
This command will temporarily overwrite the `model` value for the duration of the current command. We're currently This approach is especially beneficial for temporary changes or for testing varying configurations.
working on adding similar flags for other configuration values, which will allow you to adjust most aspects of the
configuration directly from the command line.
In addition, the `--config` or `-c` flag can be used to display the current configuration. This allows users to easily Moreover, you can use the `--config` or `-c` flag to view the present configuration. This handy feature allows users to
check their current settings without having to manually open and read the configuration files. swiftly verify their current settings without the need to manually inspect the configuration files.
```shell ```shell
chatgpt --config chatgpt --config
``` ```
This command will display the current configuration including any overrides applied by command line flags or the user Executing this command will display the active configuration, including any overrides instituted by environment
configuration file. variables or the user configuration file.
To facilitate convenient adjustments, the ChatGPT CLI provides two flags for swiftly modifying the `model`
and `max_tokens` parameters in your user configured `config.yaml`. These flags are `--set-model` and `--set-max-tokens`.
For instance, to update the model, use the following command:
```shell
chatgpt --set-model gpt-3.5-turbo-16k
```
This feature allows for rapid changes to key configuration parameters, optimizing your experience with the ChatGPT CLI.
## Development ## Development

View File

@@ -62,3 +62,15 @@ func (mr *MockCallerMockRecorder) Post(arg0, arg1, arg2 interface{}) *gomock.Cal
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2)
} }
// SetAPIKey mocks base method.
func (m *MockCaller) SetAPIKey(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAPIKey", arg0)
}
// SetAPIKey indicates an expected call of SetAPIKey.
func (mr *MockCallerMockRecorder) SetAPIKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAPIKey", reflect.TypeOf((*MockCaller)(nil).SetAPIKey), arg0)
}

View File

@@ -30,12 +30,21 @@ type Client struct {
historyStore history.HistoryStore historyStore history.HistoryStore
} }
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client { func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
cm := configmanager.New(cs).WithEnvironment()
configuration := cm.Config
if configuration.APIKey == "" {
return nil, errors.New("missing environment variable: " + cm.APIKeyEnvVarName())
}
caller.SetAPIKey(configuration.APIKey)
return &Client{ return &Client{
Config: configmanager.New(cs).Config, Config: configuration,
caller: caller, caller: caller,
historyStore: hs, historyStore: hs,
} }, nil
} }
func (c *Client) WithCapacity(capacity int) *Client { func (c *Client) WithCapacity(capacity int) *Client {
@@ -43,11 +52,6 @@ func (c *Client) WithCapacity(capacity int) *Client {
return c return c
} }
func (c *Client) WithModel(model string) *Client {
c.Config.Model = model
return c
}
func (c *Client) WithServiceURL(url string) *Client { func (c *Client) WithServiceURL(url string) *Client {
c.Config.URL = url c.Config.URL = url
return c return c
@@ -169,7 +173,10 @@ func (c *Client) initHistory() {
return return
} }
c.History, _ = c.historyStore.Read() if !c.Config.OmitHistory {
c.History, _ = c.historyStore.Read()
}
if len(c.History) == 0 { if len(c.History) == 0 {
c.History = []types.Message{{ c.History = []types.Message{{
Role: SystemRole, Role: SystemRole,
@@ -237,7 +244,10 @@ func (c *Client) updateHistory(response string) {
Role: AssistantRole, Role: AssistantRole,
Content: response, Content: response,
}) })
_ = c.historyStore.Write(c.History)
if !c.Config.OmitHistory {
_ = c.historyStore.Write(c.History)
}
} }
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int { func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {

View File

@@ -8,6 +8,8 @@ import (
"github.com/kardolus/chatgpt-cli/client" "github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/types" "github.com/kardolus/chatgpt-cli/types"
"github.com/kardolus/chatgpt-cli/utils" "github.com/kardolus/chatgpt-cli/utils"
"os"
"strings"
"testing" "testing"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@@ -22,9 +24,11 @@ import (
const ( const (
defaultMaxTokens = 4096 defaultMaxTokens = 4096
defaultURL = "https://api.openai.com" defaultURL = "https://api.openai.com"
defaultName = "default-name"
defaultModel = "gpt-3.5-turbo" defaultModel = "gpt-3.5-turbo"
defaultCompletionsPath = "/v1/chat/completions" defaultCompletionsPath = "/v1/chat/completions"
defaultModelsPath = "/v1/models" defaultModelsPath = "/v1/models"
envApiKey = "api-key"
) )
var ( var (
@@ -33,6 +37,7 @@ var (
mockHistoryStore *MockHistoryStore mockHistoryStore *MockHistoryStore
mockConfigStore *MockConfigStore mockConfigStore *MockConfigStore
factory *clientFactory factory *clientFactory
apiKeyEnvVar string
) )
func TestUnitClient(t *testing.T) { func TestUnitClient(t *testing.T) {
@@ -50,12 +55,28 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockConfigStore = NewMockConfigStore(mockCtrl) mockConfigStore = NewMockConfigStore(mockCtrl)
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore) factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
apiKeyEnvVar = strings.ToUpper(defaultName) + "_API_KEY"
Expect(os.Setenv(apiKeyEnvVar, envApiKey)).To(Succeed())
}) })
it.After(func() { it.After(func() {
mockCtrl.Finish() mockCtrl.Finish()
}) })
when("New()", func() {
it("fails to construct when the API key is missing", func() {
Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
_, err := client.New(mockCaller, mockConfigStore, mockHistoryStore)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(apiKeyEnvVar))
})
})
when("Query()", func() { when("Query()", func() {
var ( var (
body []byte body []byte
@@ -131,7 +152,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
} }
when("a valid http response is received", func() { when("a valid http response is received", func() {
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte) { testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte, omitHistory bool) {
const answer = "content" const answer = "content"
choice := types.Choice{ choice := types.Choice{
@@ -158,27 +179,18 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
err = json.Unmarshal(expectedBody, &request) err = json.Unmarshal(expectedBody, &request)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{ if !omitHistory {
Role: client.AssistantRole, mockHistoryStore.EXPECT().Write(append(request.Messages, types.Message{
Content: answer, Role: client.AssistantRole,
})) Content: answer,
}))
}
result, err := subject.Query(query) result, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(answer)) Expect(result).To(Equal(answer))
} }
it("uses the model specified by the WithModel method instead of the default model", func() {
const model = "overwritten"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithoutConfig().WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
})
it("uses the model specified by the configuration instead of the default model", func() { it("uses the model specified by the configuration instead of the default model", func() {
const model = "overwritten" const model = "overwritten"
@@ -190,20 +202,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
body, err = createBody(messages, model, false) body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body) testValidHTTPResponse(subject, nil, body, false)
})
it("when WithModel is used and a configuration is present, WithModel takes precedence", func() {
const model = "with-model"
messages = createMessages(nil, query)
factory.withoutHistory()
subject := factory.buildClientWithConfig(types.Config{
Model: "config-model",
}).WithModel(model)
body, err = createBody(messages, model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body)
}) })
it("returns the expected result for a non-empty history", func() { it("returns the expected result for a non-empty history", func() {
history := []types.Message{ history := []types.Message{
@@ -227,7 +226,25 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
body, err = createBody(messages, subject.Config.Model, false) body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body) testValidHTTPResponse(subject, history, body, false)
})
it("ignores history when configured to do so", func() {
mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: true}, nil).Times(1)
subject, err := client.New(mockCaller, mockConfigStore, mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
// Read and Write are never called on the history store
mockHistoryStore.EXPECT().Read().Times(0)
mockHistoryStore.EXPECT().Write(gomock.Any()).Times(0)
messages = createMessages(nil, query)
body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, nil, body, true)
}) })
it("truncates the history as expected", func() { it("truncates the history as expected", func() {
history := []types.Message{ history := []types.Message{
@@ -272,7 +289,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
body, err = createBody(messages, subject.Config.Model, false) body, err = createBody(messages, subject.Config.Model, false)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(subject, history, body) testValidHTTPResponse(subject, history, body, false)
}) })
}) })
}) })
@@ -461,6 +478,7 @@ type clientFactory struct {
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory { func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{ mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
Name: defaultName,
Model: defaultModel, Model: defaultModel,
MaxTokens: defaultMaxTokens, MaxTokens: defaultMaxTokens,
URL: defaultURL, URL: defaultURL,
@@ -476,17 +494,21 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
} }
func (f *clientFactory) buildClientWithoutConfig() *client.Client { func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1) f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
c := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore) c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50) return c.WithCapacity(50)
} }
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client { func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
f.mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1) f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
c := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore) c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50) return c.WithCapacity(50)
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/kardolus/chatgpt-cli/configmanager" "github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history" "github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http" "github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/utils"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"io" "io"
@@ -25,6 +24,7 @@ var (
interactiveMode bool interactiveMode bool
listModels bool listModels bool
modelName string modelName string
maxTokens int
GitCommit string GitCommit string
GitVersion string GitVersion string
ServiceURL string ServiceURL string
@@ -46,8 +46,8 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&showConfig, "config", "c", false, "Display the configuration") rootCmd.PersistentFlags().BoolVarP(&showConfig, "config", "c", false, "Display the configuration")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information")
rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models") rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models")
rootCmd.PersistentFlags().StringVarP(&modelName, "model", "m", "", "Use a custom GPT model by specifying the model name")
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name") rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
rootCmd.PersistentFlags().IntVar(&maxTokens, "set-max-tokens", 0, "Set a new default max token size by specifying the max tokens")
viper.AutomaticEnv() viper.AutomaticEnv()
@@ -74,6 +74,16 @@ func run(cmd *cobra.Command, args []string) error {
return nil return nil
} }
if cmd.Flag("set-max-tokens").Changed {
cm := configmanager.New(config.New())
if err := cm.WriteMaxTokens(maxTokens); err != nil {
return err
}
fmt.Println("Max tokens successfully updated to", maxTokens)
return nil
}
if clearHistory { if clearHistory {
historyHandler := history.New() historyHandler := history.New()
err := historyHandler.Delete() err := historyHandler.Delete()
@@ -85,7 +95,7 @@ func run(cmd *cobra.Command, args []string) error {
} }
if showConfig { if showConfig {
cm := configmanager.New(config.New()) cm := configmanager.New(config.New()).WithEnvironment()
if c, err := cm.ShowConfig(); err != nil { if c, err := cm.ShowConfig(); err != nil {
return err return err
@@ -95,15 +105,9 @@ func run(cmd *cobra.Command, args []string) error {
return nil return nil
} }
secret := viper.GetString(utils.OpenAIKeyEnv) client, err := client.New(http.New(), config.New(), history.New())
if secret == "" { if err != nil {
return errors.New("missing environment variable: " + utils.OpenAIKeyEnv) return err
}
client := client.New(http.New().WithSecret(secret), config.New(), history.New())
if modelName != "" {
client = client.WithModel(modelName)
} }
if ServiceURL != "" { if ServiceURL != "" {

View File

@@ -9,6 +9,7 @@ import (
) )
const ( const (
openAIName = "openai"
openAIModel = "gpt-3.5-turbo" openAIModel = "gpt-3.5-turbo"
openAIModelMaxTokens = 4096 openAIModelMaxTokens = 4096
openAIURL = "https://api.openai.com" openAIURL = "https://api.openai.com"
@@ -47,6 +48,7 @@ func (f *FileIO) Read() (types.Config, error) {
func (f *FileIO) ReadDefaults() types.Config { func (f *FileIO) ReadDefaults() types.Config {
return types.Config{ return types.Config{
Name: openAIName,
Model: openAIModel, Model: openAIModel,
MaxTokens: openAIModelMaxTokens, MaxTokens: openAIModelMaxTokens,
URL: openAIURL, URL: openAIURL,

View File

@@ -4,6 +4,10 @@ import (
"github.com/kardolus/chatgpt-cli/config" "github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/types" "github.com/kardolus/chatgpt-cli/types"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"os"
"reflect"
"strconv"
"strings"
) )
type ConfigManager struct { type ConfigManager struct {
@@ -12,28 +16,23 @@ type ConfigManager struct {
} }
func New(cs config.ConfigStore) *ConfigManager { func New(cs config.ConfigStore) *ConfigManager {
c := cs.ReadDefaults() configuration := cs.ReadDefaults()
configured, err := cs.Read() userConfig, err := cs.Read()
if err == nil { if err == nil {
if configured.Model != "" { configuration = replaceByConfigFile(configuration, userConfig)
c.Model = configured.Model
}
if configured.MaxTokens != 0 {
c.MaxTokens = configured.MaxTokens
}
if configured.URL != "" {
c.URL = configured.URL
}
if configured.CompletionsPath != "" {
c.CompletionsPath = configured.CompletionsPath
}
if configured.ModelsPath != "" {
c.ModelsPath = configured.ModelsPath
}
} }
return &ConfigManager{configStore: cs, Config: c} return &ConfigManager{configStore: cs, Config: configuration}
}
func (c *ConfigManager) WithEnvironment() *ConfigManager {
c.Config = replaceByEnvironment(c.Config)
return c
}
func (c *ConfigManager) APIKeyEnvVarName() string {
return strings.ToUpper(c.Config.Name) + "_" + "API_KEY"
} }
func (c *ConfigManager) ShowConfig() (string, error) { func (c *ConfigManager) ShowConfig() (string, error) {
@@ -45,8 +44,72 @@ func (c *ConfigManager) ShowConfig() (string, error) {
return string(data), nil return string(data), nil
} }
func (c *ConfigManager) WriteMaxTokens(tokens int) error {
c.Config.MaxTokens = tokens
return c.configStore.Write(c.Config)
}
func (c *ConfigManager) WriteModel(model string) error { func (c *ConfigManager) WriteModel(model string) error {
c.Config.Model = model c.Config.Model = model
return c.configStore.Write(c.Config) return c.configStore.Write(c.Config)
} }
func replaceByConfigFile(defaultConfig, userConfig types.Config) types.Config {
t := reflect.TypeOf(defaultConfig)
vDefault := reflect.ValueOf(&defaultConfig).Elem()
vUser := reflect.ValueOf(userConfig)
for i := 0; i < t.NumField(); i++ {
defaultField := vDefault.Field(i)
userField := vUser.Field(i)
switch defaultField.Kind() {
case reflect.String:
if userStr := userField.String(); userStr != "" {
defaultField.SetString(userStr)
}
case reflect.Int:
if userInt := int(userField.Int()); userInt != 0 {
defaultField.SetInt(int64(userInt))
}
case reflect.Bool:
if userBool := userField.Bool(); &userBool != nil {
defaultField.SetBool(userBool)
}
}
}
return defaultConfig
}
func replaceByEnvironment(configuration types.Config) types.Config {
t := reflect.TypeOf(configuration)
v := reflect.ValueOf(&configuration).Elem()
prefix := strings.ToUpper(configuration.Name) + "_"
for i := 0; i < t.NumField(); i++ {
tag := t.Field(i).Tag.Get("yaml")
if tag == "name" {
continue
}
if value := os.Getenv(prefix + strings.ToUpper(tag)); value != "" {
field := v.Field(i)
switch field.Kind() {
case reflect.String:
field.SetString(value)
case reflect.Int:
intValue, _ := strconv.Atoi(value)
field.SetInt(int64(intValue))
case reflect.Bool:
boolValue, _ := strconv.ParseBool(value)
field.SetBool(boolValue)
}
}
}
return configuration
}

View File

@@ -9,6 +9,9 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/sclevine/spec/report" "github.com/sclevine/spec/report"
"os"
"strconv"
"strings"
"testing" "testing"
) )
@@ -21,16 +24,20 @@ func TestUnitConfigManager(t *testing.T) {
func testConfig(t *testing.T, when spec.G, it spec.S) { func testConfig(t *testing.T, when spec.G, it spec.S) {
const ( const (
defaultMaxTokens = 10 defaultMaxTokens = 10
defaultName = "default-name"
defaultURL = "default-url" defaultURL = "default-url"
defaultModel = "default-model" defaultModel = "default-model"
defaultApiKey = "default-api-key"
defaultCompletionsPath = "default-completions-path" defaultCompletionsPath = "default-completions-path"
defaultModelsPath = "default-models-path" defaultModelsPath = "default-models-path"
defaultOmitHistory = false
) )
var ( var (
mockCtrl *gomock.Controller mockCtrl *gomock.Controller
mockConfigStore *MockConfigStore mockConfigStore *MockConfigStore
defaultConfig types.Config defaultConfig types.Config
envPrefix string
) )
it.Before(func() { it.Before(func() {
@@ -39,12 +46,17 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore = NewMockConfigStore(mockCtrl) mockConfigStore = NewMockConfigStore(mockCtrl)
defaultConfig = types.Config{ defaultConfig = types.Config{
Name: defaultName,
APIKey: defaultApiKey,
Model: defaultModel, Model: defaultModel,
MaxTokens: defaultMaxTokens, MaxTokens: defaultMaxTokens,
URL: defaultURL, URL: defaultURL,
CompletionsPath: defaultCompletionsPath, CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath, ModelsPath: defaultModelsPath,
OmitHistory: defaultOmitHistory,
} }
envPrefix = strings.ToUpper(defaultConfig.Name) + "_"
}) })
it.After(func() { it.After(func() {
@@ -52,17 +64,27 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
}) })
when("Constructing a new ConfigManager", func() { when("Constructing a new ConfigManager", func() {
it.Before(func() {
cleanEnv(envPrefix)
})
it.After(func() {
cleanEnv(envPrefix)
})
it("applies the default configuration when user config is missing", func() { it("applies the default configuration when user config is missing", func() {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("no such file")).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("no such file")).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Model).To(Equal(defaultModel)) Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL)) Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
it("gives precedence to the user provided model", func() { it("gives precedence to the user provided model", func() {
userModel := "the-model" userModel := "the-model"
@@ -70,13 +92,33 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{Model: userModel}, nil).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{Model: userModel}, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Model).To(Equal(userModel)) Expect(subject.Config.Model).To(Equal(userModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL)) Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the user provided name", func() {
userName := "the-name"
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{Name: userName}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.Name).To(Equal(userName))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
it("gives precedence to the user provided max-tokens", func() { it("gives precedence to the user provided max-tokens", func() {
userMaxTokens := 42 userMaxTokens := 42
@@ -84,13 +126,17 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{MaxTokens: userMaxTokens}, nil).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{MaxTokens: userMaxTokens}, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel)) Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(userMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(userMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL)) Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
it("gives precedence to the user provided URL", func() { it("gives precedence to the user provided URL", func() {
userURL := "the-user-url" userURL := "the-user-url"
@@ -98,13 +144,16 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{URL: userURL}, nil).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{URL: userURL}, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel)) Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(userURL)) Expect(subject.Config.URL).To(Equal(userURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
it("gives precedence to the user provided completions-path", func() { it("gives precedence to the user provided completions-path", func() {
completionsPath := "the-completions-path" completionsPath := "the-completions-path"
@@ -112,13 +161,16 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{CompletionsPath: completionsPath}, nil).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{CompletionsPath: completionsPath}, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel)) Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL)) Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(completionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(completionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
it("gives precedence to the user provided models-path", func() { it("gives precedence to the user provided models-path", func() {
modelsPath := "the-models-path" modelsPath := "the-models-path"
@@ -126,13 +178,204 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{ModelsPath: modelsPath}, nil).Times(1) mockConfigStore.EXPECT().Read().Return(types.Config{ModelsPath: modelsPath}, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel)) Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL)) Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(modelsPath)) Expect(subject.Config.ModelsPath).To(Equal(modelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the user provided api-key", func() {
apiKey := "new-api-key"
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{APIKey: apiKey}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(apiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the user provided omit-history", func() {
omitHistory := true
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: omitHistory}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(omitHistory))
})
it("gives precedence to the OMIT_HISTORY environment variable", func() {
var (
environmentValue = true
configValue = false
)
Expect(os.Setenv(envPrefix+"OMIT_HISTORY", strconv.FormatBool(environmentValue))).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: configValue}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(environmentValue))
})
it("gives precedence to the API_KEY environment variable", func() {
var (
environmentKey = "environment-api-key"
configKey = "config-api-key"
)
Expect(os.Setenv(envPrefix+"API_KEY", environmentKey)).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{APIKey: configKey}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(environmentKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the MODEL environment variable", func() {
var (
envModel = "environment-model"
confModel = "config-model"
)
Expect(os.Setenv(envPrefix+"MODEL", envModel)).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{Model: confModel}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(envModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the MAX_TOKENS environment variable", func() {
var (
envMaxTokens = 42
confMaxTokens = 4242
)
Expect(os.Setenv(envPrefix+"MAX_TOKENS", strconv.Itoa(envMaxTokens))).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{MaxTokens: confMaxTokens}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(envMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the URL environment variable", func() {
var (
envURL = "environment-url"
confURL = "config-url"
)
Expect(os.Setenv(envPrefix+"URL", envURL)).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{URL: confURL}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(envURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the COMPLETIONS_PATH environment variable", func() {
var (
envCompletionsPath = "environment-completions-path"
confCompletionsPath = "config-completions-path"
)
Expect(os.Setenv(envPrefix+"COMPLETIONS_PATH", envCompletionsPath)).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{CompletionsPath: confCompletionsPath}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(envCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
})
it("gives precedence to the MODELS_PATH environment variable", func() {
var (
envModelsPath = "environment-models-path"
confModelsPath = "config-models-path"
)
Expect(os.Setenv(envPrefix+"MODELS_PATH", envModelsPath)).To(Succeed())
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{ModelsPath: confModelsPath}, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(envModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
}) })
@@ -141,15 +384,18 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(defaultConfig, nil).Times(1) mockConfigStore.EXPECT().Read().Return(defaultConfig, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
result, err := subject.ShowConfig() result, err := subject.ShowConfig()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result).To(ContainSubstring(defaultName))
Expect(result).To(ContainSubstring(defaultApiKey))
Expect(result).To(ContainSubstring(defaultModel)) Expect(result).To(ContainSubstring(defaultModel))
Expect(result).To(ContainSubstring(defaultURL)) Expect(result).To(ContainSubstring(defaultURL))
Expect(result).To(ContainSubstring(defaultCompletionsPath)) Expect(result).To(ContainSubstring(defaultCompletionsPath))
Expect(result).To(ContainSubstring(defaultModelsPath)) Expect(result).To(ContainSubstring(defaultModelsPath))
Expect(result).To(ContainSubstring(fmt.Sprintf("%d", defaultMaxTokens))) Expect(result).To(ContainSubstring(fmt.Sprintf("%d", defaultMaxTokens)))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
}) })
}) })
@@ -158,7 +404,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(defaultConfig, nil).Times(1) mockConfigStore.EXPECT().Read().Return(defaultConfig, nil).Times(1)
subject := configmanager.New(mockConfigStore) subject := configmanager.New(mockConfigStore).WithEnvironment()
modelName := "the-model" modelName := "the-model"
subject.Config.Model = modelName subject.Config.Model = modelName
@@ -167,4 +413,29 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.WriteModel(modelName)).To(Succeed()) Expect(subject.WriteModel(modelName)).To(Succeed())
}) })
}) })
when("WriteMaxTokens()", func() {
it("writes the expected config file", func() {
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(defaultConfig, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
maxTokens := 9879284
subject.Config.MaxTokens = maxTokens
mockConfigStore.EXPECT().Write(subject.Config).Times(1)
Expect(subject.WriteMaxTokens(maxTokens)).To(Succeed())
})
})
}
func cleanEnv(envPrefix string) {
Expect(os.Unsetenv(envPrefix + "API_KEY")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "MODEL")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "MAX_TOKENS")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "URL")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "COMPLETIONS_PATH")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "MODELS_PATH")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "OMIT_HISTORY")).To(Succeed())
} }

View File

@@ -26,6 +26,7 @@ const (
type Caller interface { type Caller interface {
Post(url string, body []byte, stream bool) ([]byte, error) Post(url string, body []byte, stream bool) ([]byte, error)
Get(url string) ([]byte, error) Get(url string) ([]byte, error)
SetAPIKey(secret string)
} }
type RestCaller struct { type RestCaller struct {
@@ -42,9 +43,8 @@ func New() *RestCaller {
} }
} }
func (r *RestCaller) WithSecret(secret string) *RestCaller { func (r *RestCaller) SetAPIKey(secret string) {
r.secret = secret r.secret = secret
return r
} }
func (r *RestCaller) Get(url string) ([]byte, error) { func (r *RestCaller) Get(url string) ([]byte, error) {

View File

@@ -4,9 +4,9 @@ import (
"encoding/json" "encoding/json"
"github.com/kardolus/chatgpt-cli/client" "github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/config" "github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/http" "github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types" "github.com/kardolus/chatgpt-cli/types"
"github.com/kardolus/chatgpt-cli/utils"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/sclevine/spec/report" "github.com/sclevine/spec/report"
@@ -27,10 +27,12 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
it.Before(func() { it.Before(func() {
RegisterTestingT(t) RegisterTestingT(t)
apiKey := os.Getenv(utils.OpenAIKeyEnv) apiKey := os.Getenv(configmanager.New(config.New()).WithEnvironment().APIKeyEnvVarName())
Expect(apiKey).NotTo(BeEmpty()) Expect(apiKey).NotTo(BeEmpty())
restCaller = http.New().WithSecret(apiKey) restCaller = http.New()
restCaller.SetAPIKey(apiKey)
defaults = config.New().ReadDefaults() defaults = config.New().ReadDefaults()
}) })

View File

@@ -3,9 +3,9 @@ package integration_test
import ( import (
"fmt" "fmt"
"github.com/kardolus/chatgpt-cli/config" "github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history" "github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/types" "github.com/kardolus/chatgpt-cli/types"
"github.com/kardolus/chatgpt-cli/utils"
"github.com/onsi/gomega/gexec" "github.com/onsi/gomega/gexec"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/sclevine/spec/report" "github.com/sclevine/spec/report"
@@ -13,6 +13,8 @@ import (
"os" "os"
"os/exec" "os/exec"
"path" "path"
"strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -155,11 +157,13 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
const ( const (
exitSuccess = 0 exitSuccess = 0
exitFailure = 1 exitFailure = 1
apiKey = "some-key"
) )
var ( var (
homeDir string homeDir string
err error err error
apiKeyEnvVar string
) )
it.Before(func() { it.Before(func() {
@@ -176,8 +180,10 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
homeDir, err = os.MkdirTemp("", "mockHome") homeDir, err = os.MkdirTemp("", "mockHome")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
apiKeyEnvVar = configmanager.New(config.New()).WithEnvironment().APIKeyEnvVarName()
Expect(os.Setenv("HOME", homeDir)).To(Succeed()) Expect(os.Setenv("HOME", homeDir)).To(Succeed())
Expect(os.Setenv(utils.OpenAIKeyEnv, "some-key")).To(Succeed()) Expect(os.Setenv(apiKeyEnvVar, apiKey)).To(Succeed())
}) })
it.After(func() { it.After(func() {
@@ -186,7 +192,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
}) })
it("throws an error when the API key is missing", func() { it("throws an error when the API key is missing", func() {
Expect(os.Unsetenv(utils.OpenAIKeyEnv)).To(Succeed()) Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "some prompt") command := exec.Command(binaryPath, "some prompt")
session, err := gexec.Start(command, io.Discard, io.Discard) session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -195,11 +201,11 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Eventually(session).Should(gexec.Exit(exitFailure)) Eventually(session).Should(gexec.Exit(exitFailure))
output := string(session.Out.Contents()) output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(utils.OpenAIKeyEnv)) Expect(output).To(ContainSubstring(apiKeyEnvVar))
}) })
it("should not require an API key for the --version flag", func() { it("should not require an API key for the --version flag", func() {
Expect(os.Unsetenv(utils.OpenAIKeyEnv)).To(Succeed()) Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "--version") command := exec.Command(binaryPath, "--version")
session, err := gexec.Start(command, io.Discard, io.Discard) session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -209,7 +215,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
}) })
it("should not require an API key for the --clear-history flag", func() { it("should not require an API key for the --clear-history flag", func() {
Expect(os.Unsetenv(utils.OpenAIKeyEnv)).To(Succeed()) Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "--clear-history") command := exec.Command(binaryPath, "--clear-history")
session, err := gexec.Start(command, io.Discard, io.Discard) session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -229,8 +235,19 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(output).To(ContainSubstring("flag needs an argument: --set-model")) Expect(output).To(ContainSubstring("flag needs an argument: --set-model"))
}) })
it("should require an argument for the --set-max-tokens flag", func() {
command := exec.Command(binaryPath, "--set-max-tokens")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitFailure))
output := string(session.Out.Contents())
Expect(output).To(ContainSubstring("flag needs an argument: --set-max-tokens"))
})
it("should require the chatgpt-cli folder but not an API key for the --set-model flag", func() { it("should require the chatgpt-cli folder but not an API key for the --set-model flag", func() {
Expect(os.Unsetenv(utils.OpenAIKeyEnv)).To(Succeed()) Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "--set-model", "123") command := exec.Command(binaryPath, "--set-model", "123")
session, err := gexec.Start(command, io.Discard, io.Discard) session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -240,7 +257,21 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
output := string(session.Out.Contents()) output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(".chatgpt-cli/config.yaml: no such file or directory")) Expect(output).To(ContainSubstring(".chatgpt-cli/config.yaml: no such file or directory"))
Expect(output).NotTo(ContainSubstring(utils.OpenAIKeyEnv)) Expect(output).NotTo(ContainSubstring(apiKeyEnvVar))
})
it("should require the chatgpt-cli folder but not an API key for the --set-max-tokens flag", func() {
Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "--set-max-tokens", "789")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitFailure))
output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(".chatgpt-cli/config.yaml: no such file or directory"))
Expect(output).NotTo(ContainSubstring(apiKeyEnvVar))
}) })
it("should return the expected result for the --version flag", func() { it("should return the expected result for the --version flag", func() {
@@ -329,9 +360,34 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
// History should no longer exist // History should no longer exist
Expect(historyFile).NotTo(BeAnExistingFile()) Expect(historyFile).NotTo(BeAnExistingFile())
// environment takes precedence
omitHistoryEnvKey := strings.Replace(apiKeyEnvVar, "API_KEY", "OMIT_HISTORY", 1)
envValue := "true"
Expect(os.Setenv(omitHistoryEnvKey, envValue)).To(Succeed())
// Perform a query
command = exec.Command(binaryPath, "--query", "some-query")
session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
// The CLI response should be as expected
Eventually(session).Should(gexec.Exit(exitSuccess))
output = string(session.Out.Contents())
response = `I don't have personal opinions about bars, but here are some popular bars in Red Hook, Brooklyn:`
Expect(output).To(ContainSubstring(response))
// The history file should NOT exist
Expect(historyFile).NotTo(BeAnExistingFile())
}) })
it("has a configurable default model", func() { it("has a configurable default model", func() {
oldModel := "gpt-3.5-turbo"
newModel := "gpt-3.5-turbo-0301"
// config.yaml should not exist yet // config.yaml should not exist yet
configFile := path.Join(filePath, "config.yaml") configFile := path.Join(filePath, "config.yaml")
Expect(configFile).NotTo(BeAnExistingFile()) Expect(configFile).NotTo(BeAnExistingFile())
@@ -346,8 +402,8 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
output := string(session.Out.Contents()) output := string(session.Out.Contents())
// see models.json // see models.json
Expect(output).To(ContainSubstring("* gpt-3.5-turbo (current)")) Expect(output).To(ContainSubstring(fmt.Sprintf("* %s (current)", oldModel)))
Expect(output).To(ContainSubstring("- gpt-3.5-turbo-0301")) Expect(output).To(ContainSubstring(fmt.Sprintf("- %s", newModel)))
// --config displays the default model as well // --config displays the default model as well
command = exec.Command(binaryPath, "--config") command = exec.Command(binaryPath, "--config")
@@ -358,10 +414,10 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
output = string(session.Out.Contents()) output = string(session.Out.Contents())
Expect(output).To(ContainSubstring("gpt-3.5-turbo")) Expect(output).To(ContainSubstring(oldModel))
// Set the model // Set the model
command = exec.Command(binaryPath, "--set-model", "gpt-3.5-turbo-0301") command = exec.Command(binaryPath, "--set-model", newModel)
session, err = gexec.Start(command, io.Discard, io.Discard) session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@@ -371,6 +427,14 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
// config.yaml should have been created // config.yaml should have been created
Expect(configFile).To(BeAnExistingFile()) Expect(configFile).To(BeAnExistingFile())
contentBytes, err := os.ReadFile(configFile)
Expect(err).ShouldNot(HaveOccurred())
// config.yaml should have the expected content
content := string(contentBytes)
Expect(content).NotTo(ContainSubstring(apiKey))
Expect(content).To(ContainSubstring(newModel))
// --list-models shows the new model as default // --list-models shows the new model as default
command = exec.Command(binaryPath, "--list-models") command = exec.Command(binaryPath, "--list-models")
session, err = gexec.Start(command, io.Discard, io.Discard) session, err = gexec.Start(command, io.Discard, io.Discard)
@@ -380,8 +444,8 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
output = string(session.Out.Contents()) output = string(session.Out.Contents())
Expect(output).To(ContainSubstring("- gpt-3.5-turbo")) Expect(output).To(ContainSubstring(fmt.Sprintf("- %s", oldModel)))
Expect(output).To(ContainSubstring("* gpt-3.5-turbo-0301 (current)")) Expect(output).To(ContainSubstring(fmt.Sprintf("* %s (current)", newModel)))
// --config displays the new model as well // --config displays the new model as well
command = exec.Command(binaryPath, "--config") command = exec.Command(binaryPath, "--config")
@@ -391,8 +455,89 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Eventually(session).Should(gexec.Exit(exitSuccess)) Eventually(session).Should(gexec.Exit(exitSuccess))
output = string(session.Out.Contents()) output = string(session.Out.Contents())
Expect(output).To(ContainSubstring(newModel))
Expect(output).To(ContainSubstring("gpt-3.5-turbo-0301")) // environment takes precedence
modelEnvKey := strings.Replace(apiKeyEnvVar, "API_KEY", "MODEL", 1)
envModel := "new-model"
Expect(os.Setenv(modelEnvKey, envModel)).To(Succeed())
command = exec.Command(binaryPath, "--config")
session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitSuccess))
output = string(session.Out.Contents())
Expect(output).To(ContainSubstring(envModel))
Expect(os.Unsetenv(modelEnvKey)).To(Succeed())
})
it("has a configurable default max-tokens", func() {
defaults := config.New().ReadDefaults()
// config.yaml should not exist yet
configFile := path.Join(filePath, "config.yaml")
Expect(configFile).NotTo(BeAnExistingFile())
// --config displays the default max-tokens as well
command := exec.Command(binaryPath, "--config")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitSuccess))
output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(strconv.Itoa(defaults.MaxTokens)))
// Set the max-tokens
newMaxTokens := "81724"
command = exec.Command(binaryPath, "--set-max-tokens", newMaxTokens)
session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
// The CLI response should be as expected
Eventually(session).Should(gexec.Exit(exitSuccess))
// config.yaml should have been created
Expect(configFile).To(BeAnExistingFile())
contentBytes, err := os.ReadFile(configFile)
Expect(err).ShouldNot(HaveOccurred())
// config.yaml should have the expected content
content := string(contentBytes)
Expect(content).NotTo(ContainSubstring(apiKey))
Expect(content).To(ContainSubstring(newMaxTokens))
// --config displays the new max-tokens as well
command = exec.Command(binaryPath, "--config")
session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitSuccess))
output = string(session.Out.Contents())
Expect(output).To(ContainSubstring(newMaxTokens))
// environment takes precedence
modelEnvKey := strings.Replace(apiKeyEnvVar, "API_KEY", "MAX_TOKENS", 1)
Expect(os.Setenv(modelEnvKey, newMaxTokens)).To(Succeed())
command = exec.Command(binaryPath, "--config")
session, err = gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitSuccess))
output = string(session.Out.Contents())
Expect(output).To(ContainSubstring(newMaxTokens))
Expect(os.Unsetenv(modelEnvKey)).To(Succeed())
}) })
}) })
}) })

View File

@@ -1,9 +1,12 @@
package types package types
type Config struct { type Config struct {
Name string `yaml:"name"`
APIKey string `yaml:"api_key"`
Model string `yaml:"model"` Model string `yaml:"model"`
MaxTokens int `yaml:"max_tokens"` MaxTokens int `yaml:"max_tokens"`
URL string `yaml:"url"` URL string `yaml:"url"`
CompletionsPath string `yaml:"completions_path"` CompletionsPath string `yaml:"completions_path"`
ModelsPath string `yaml:"models_path"` ModelsPath string `yaml:"models_path"`
OmitHistory bool `yaml:"omit_history"`
} }

View File

@@ -1,3 +0,0 @@
package utils
const OpenAIKeyEnv = "OPENAI_API_KEY"