From 9366538ee9b6af1af38f677f770e14af80cf4109 Mon Sep 17 00:00:00 2001 From: kardolus Date: Sun, 11 Feb 2024 15:24:04 -0600 Subject: [PATCH] Add --list-threads --- client/configmocks_test.go | 15 ++++++++ cmd/chatgpt/main.go | 16 ++++++++ config/store.go | 33 ++++++++++++++-- configmanager/configmanager.go | 32 ++++++++++++++++ configmanager/configmanager_test.go | 39 +++++++++++++++++++ configmanager/configmocks_test.go | 15 ++++++++ history/store.go | 19 +++------- integration/integration_test.go | 59 ++++++++++++++++++++++++++++- utils/utils.go | 16 +++++++- 9 files changed, 224 insertions(+), 20 deletions(-) diff --git a/client/configmocks_test.go b/client/configmocks_test.go index 315745e..99b36a7 100644 --- a/client/configmocks_test.go +++ b/client/configmocks_test.go @@ -34,6 +34,21 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder { return m.recorder } +// List mocks base method. +func (m *MockConfigStore) List() ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockConfigStoreMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConfigStore)(nil).List)) +} + // Read mocks base method. func (m *MockConfigStore) Read() (types.Config, error) { m.ctrl.T.Helper() diff --git a/cmd/chatgpt/main.go b/cmd/chatgpt/main.go index 8df5047..47475ca 100644 --- a/cmd/chatgpt/main.go +++ b/cmd/chatgpt/main.go @@ -24,6 +24,7 @@ var ( showConfig bool interactiveMode bool listModels bool + listThreads bool modelName string threadName string maxTokens int @@ -50,6 +51,7 @@ func main() { rootCmd.PersistentFlags().BoolVarP(&showConfig, "config", "c", false, "Display the configuration") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information") rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models") + rootCmd.PersistentFlags().BoolVarP(&listThreads, "list-threads", "", false, "List available threads") rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name") rootCmd.PersistentFlags().StringVar(&threadName, "set-thread", "", "Set a new active thread by specifying the thread name") rootCmd.PersistentFlags().IntVar(&maxTokens, "set-max-tokens", 0, "Set a new default max token size by specifying the max tokens") @@ -159,6 +161,20 @@ func run(cmd *cobra.Command, args []string) error { return nil } + if listThreads { + cm := configmanager.New(config.New()) + + threads, err := cm.ListThreads() + if err != nil { + return err + } + fmt.Println("Available threads:") + for _, thread := range threads { + fmt.Println(thread) + } + return nil + } + if interactiveMode { fmt.Println("Entering interactive mode. Type 'exit' and press Enter or press Ctrl+C to quit.") diff --git a/config/store.go b/config/store.go index 811432d..c0715af 100644 --- a/config/store.go +++ b/config/store.go @@ -26,6 +26,7 @@ const ( ) type ConfigStore interface { + List() ([]string, error) Read() (types.Config, error) ReadDefaults() types.Config Write(types.Config) error @@ -35,21 +36,45 @@ type ConfigStore interface { var _ ConfigStore = &FileIO{} type FileIO struct { - configFilePath string + configFilePath string + historyFilePath string } func New() *FileIO { - path, _ := getPath() + configPath, _ := getPath() + historyPath, _ := utils.GetHistoryDir() + return &FileIO{ - configFilePath: path, + configFilePath: configPath, + historyFilePath: historyPath, } } -func (f *FileIO) WithFilePath(configFilePath string) *FileIO { +func (f *FileIO) WithConfigPath(configFilePath string) *FileIO { f.configFilePath = configFilePath return f } +func (f *FileIO) WithHistoryPath(historyPath string) *FileIO { + f.historyFilePath = historyPath + return f +} + +func (f *FileIO) List() ([]string, error) { + var result []string + + files, err := os.ReadDir(f.historyFilePath) + if err != nil { + return nil, err + } + + for _, file := range files { + result = append(result, file.Name()) + } + + return result, nil +} + func (f *FileIO) Read() (types.Config, error) { return parseFile(f.configFilePath) } diff --git a/configmanager/configmanager.go b/configmanager/configmanager.go index ee65937..9d8dc4e 100644 --- a/configmanager/configmanager.go +++ b/configmanager/configmanager.go @@ -1,6 +1,7 @@ package configmanager import ( + "fmt" "github.com/kardolus/chatgpt-cli/config" "github.com/kardolus/chatgpt-cli/types" "gopkg.in/yaml.v3" @@ -35,6 +36,31 @@ func (c *ConfigManager) APIKeyEnvVarName() string { return strings.ToUpper(c.Config.Name) + "_" + "API_KEY" } +// ListThreads retrieves a list of all threads stored in the configuration. +// It marks the current thread with an asterisk (*) and returns the list sorted alphabetically. +// If an error occurs while retrieving the threads from the config store, it returns the error. +func (c *ConfigManager) ListThreads() ([]string, error) { + var result []string + + threads, err := c.configStore.List() + if err != nil { + return nil, err + } + + for _, thread := range threads { + thread = strings.ReplaceAll(thread, ".json", "") + if thread != c.Config.Thread { + result = append(result, fmt.Sprintf("- %s", thread)) + continue + } + result = append(result, fmt.Sprintf("* %s (current)", thread)) + } + + return result, nil +} + +// ShowConfig serializes the current configuration to a YAML string. +// It returns the serialized string or an error if the serialization fails. func (c *ConfigManager) ShowConfig() (string, error) { data, err := yaml.Marshal(c.Config) if err != nil { @@ -44,18 +70,24 @@ func (c *ConfigManager) ShowConfig() (string, error) { return string(data), nil } +// WriteMaxTokens updates the maximum number of tokens in the current configuration. +// It writes the updated configuration to the config store and returns an error if the write fails. func (c *ConfigManager) WriteMaxTokens(tokens int) error { c.Config.MaxTokens = tokens return c.configStore.Write(c.Config) } +// WriteModel updates the model in the current configuration. +// It writes the updated configuration to the config store and returns an error if the write fails. func (c *ConfigManager) WriteModel(model string) error { c.Config.Model = model return c.configStore.Write(c.Config) } +// WriteThread updates the current thread in the configuration. +// It writes the updated configuration to the config store and returns an error if the write fails. func (c *ConfigManager) WriteThread(thread string) error { c.Config.Thread = thread diff --git a/configmanager/configmanager_test.go b/configmanager/configmanager_test.go index 62712f2..23ab8d3 100644 --- a/configmanager/configmanager_test.go +++ b/configmanager/configmanager_test.go @@ -262,6 +262,45 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { subject.WriteThread(thread) }) }) + + when("ListThreads()", func() { + activeThread := "active-thread" + + it.Before(func() { + userConfig := types.Config{Thread: activeThread} + + mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes() + mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1) + }) + + it("throws an error when the List call fails", func() { + subject := configmanager.New(mockConfigStore).WithEnvironment() + + errorInstance := errors.New("an error occurred") + mockConfigStore.EXPECT().List().Return(nil, errorInstance).Times(1) + + _, err := subject.ListThreads() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(errorInstance)) + }) + + it("returns the expected threads", func() { + subject := configmanager.New(mockConfigStore).WithEnvironment() + + threads := []string{"thread1.json", "thread2.json", activeThread + ".json"} + mockConfigStore.EXPECT().List().Return(threads, nil).Times(1) + + result, err := subject.ListThreads() + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(HaveLen(3)) + Expect(result[0]).NotTo(ContainSubstring("current")) + Expect(result[0]).NotTo(ContainSubstring("json")) + Expect(result[1]).NotTo(ContainSubstring("current")) + Expect(result[1]).NotTo(ContainSubstring("json")) + Expect(result[2]).To(ContainSubstring("current")) + Expect(result[2]).NotTo(ContainSubstring("json")) + }) + }) } func performWriteTest(mockConfigStore *MockConfigStore, defaultConfig types.Config, expectedValue interface{}, fieldName string, action func()) { diff --git a/configmanager/configmocks_test.go b/configmanager/configmocks_test.go index 19a269d..6aa2bf1 100644 --- a/configmanager/configmocks_test.go +++ b/configmanager/configmocks_test.go @@ -34,6 +34,21 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder { return m.recorder } +// List mocks base method. +func (m *MockConfigStore) List() ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockConfigStoreMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConfigStore)(nil).List)) +} + // Read mocks base method. func (m *MockConfigStore) Read() (types.Config, error) { m.ctrl.T.Helper() diff --git a/history/store.go b/history/store.go index 19f61c7..09562b3 100644 --- a/history/store.go +++ b/history/store.go @@ -11,8 +11,7 @@ import ( ) const ( - historyDirName = "history" - jsonExtension = ".json" + jsonExtension = ".json" ) type HistoryStore interface { @@ -33,7 +32,7 @@ type FileIO struct { func New() (*FileIO, error) { _ = migrate() - dir, err := getHistoryDir() + dir, err := utils.GetHistoryDir() if err != nil { return nil, err } @@ -88,15 +87,6 @@ func (f *FileIO) getPath() string { return filepath.Join(f.historyDir, f.thread+jsonExtension) } -func getHistoryDir() (string, error) { - homeDir, err := utils.GetChatGPTDirectory() - if err != nil { - return "", err - } - - return filepath.Join(homeDir, historyDirName), nil -} - // migrate moves the legacy "history" file in ~/.chatgpt-cli to "history/default.json" func migrate() error { hiddenDir, err := utils.GetChatGPTDirectory() @@ -104,7 +94,10 @@ func migrate() error { return err } - historyFile := path.Join(hiddenDir, historyDirName) + historyFile, err := utils.GetHistoryDir() + if err != nil { + return err + } fileInfo, err := os.Stat(historyFile) if err != nil { diff --git a/integration/integration_test.go b/integration/integration_test.go index 95f257f..93b5902 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -15,6 +15,7 @@ import ( "os" "os/exec" "path" + "path/filepath" "strconv" "strings" "sync" @@ -102,10 +103,11 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { }) }) - when("Read, Write Config", func() { + when("Read, Write, List Config", func() { var ( tmpDir string tmpFile *os.File + historyDir string configIO *config.FileIO testConfig types.Config err error @@ -115,12 +117,15 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test") Expect(err).NotTo(HaveOccurred()) + historyDir, err = os.MkdirTemp(tmpDir, "history") + Expect(err).NotTo(HaveOccurred()) + tmpFile, err = os.CreateTemp(tmpDir, "config.yaml") Expect(err).NotTo(HaveOccurred()) Expect(tmpFile.Close()).To(Succeed()) - configIO = config.New().WithFilePath(tmpFile.Name()) + configIO = config.New().WithConfigPath(tmpFile.Name()).WithHistoryPath(historyDir) testConfig = types.Config{ Model: "test-model", @@ -145,6 +150,22 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { Expect(readConfig).To(Equal(testConfig)) }) + it("lists all the threads", func() { + files := []string{"thread1.json", "thread2.json", "thread3.json"} + + for _, file := range files { + file, err := os.Create(filepath.Join(historyDir, file)) + Expect(err).NotTo(HaveOccurred()) + + Expect(file.Close()).To(Succeed()) + } + + result, err := configIO.List() + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(HaveLen(3)) + Expect(result[2]).To(Equal("thread3.json")) + }) + // Since we don't have a Delete method in the config, we will test if we can overwrite the configuration. it("overwrites the existing config", func() { newConfig := types.Config{ @@ -255,6 +276,17 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { Expect(output).To(ContainSubstring(".chatgpt-cli: no such file or directory")) }) + it("should require a hidden folder for the --list-threads flag", func() { + command := exec.Command(binaryPath, "--list-threads") + session, err := gexec.Start(command, io.Discard, io.Discard) + Expect(err).NotTo(HaveOccurred()) + + Eventually(session).Should(gexec.Exit(exitFailure)) + + output := string(session.Out.Contents()) + Expect(output).To(ContainSubstring(".chatgpt-cli/history: no such file or directory")) + }) + it("should require an argument for the --set-model flag", func() { command := exec.Command(binaryPath, "--set-model") session, err := gexec.Start(command, io.Discard, io.Discard) @@ -468,6 +500,29 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { Expect(os.Unsetenv(omitHistoryEnvKey)).To(Succeed()) }) + it("should return the expected result for the --list-threads flag", func() { + historyDir := path.Join(filePath, "history") + Expect(os.Mkdir(historyDir, 0755)).To(Succeed()) + + files := []string{"thread1.json", "thread2.json", "thread3.json", "default.json"} + + os.Mkdir(historyDir, 7555) + + for _, file := range files { + file, err := os.Create(filepath.Join(historyDir, file)) + Expect(err).NotTo(HaveOccurred()) + + Expect(file.Close()).To(Succeed()) + } + + output := runCommand("--list-threads") + + Expect(output).To(ContainSubstring("* default (current)")) + Expect(output).To(ContainSubstring("- thread1")) + Expect(output).To(ContainSubstring("- thread2")) + Expect(output).To(ContainSubstring("- thread3")) + }) + when("configurable flags are set", func() { it.Before(func() { configFile = path.Join(filePath, "config.yaml") diff --git a/utils/utils.go b/utils/utils.go index 40a3e7b..9471ea3 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,10 +5,24 @@ import ( "path/filepath" ) +const ( + cliDirName = ".chatgpt-cli" + historyDirName = "history" +) + func GetChatGPTDirectory() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { return "", err } - return filepath.Join(homeDir, ".chatgpt-cli"), nil + return filepath.Join(homeDir, cliDirName), nil +} + +func GetHistoryDir() (string, error) { + homeDir, err := GetChatGPTDirectory() + if err != nil { + return "", err + } + + return filepath.Join(homeDir, historyDirName), nil }