Add --list-threads

This commit is contained in:
kardolus
2024-02-11 15:24:04 -06:00
parent f22df8793e
commit 9366538ee9
9 changed files with 224 additions and 20 deletions

View File

@@ -34,6 +34,21 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// List mocks base method.
func (m *MockConfigStore) List() ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockConfigStoreMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConfigStore)(nil).List))
}
// Read mocks base method.
func (m *MockConfigStore) Read() (types.Config, error) {
m.ctrl.T.Helper()

View File

@@ -24,6 +24,7 @@ var (
showConfig bool
interactiveMode bool
listModels bool
listThreads bool
modelName string
threadName string
maxTokens int
@@ -50,6 +51,7 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&showConfig, "config", "c", false, "Display the configuration")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information")
rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models")
rootCmd.PersistentFlags().BoolVarP(&listThreads, "list-threads", "", false, "List available threads")
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
rootCmd.PersistentFlags().StringVar(&threadName, "set-thread", "", "Set a new active thread by specifying the thread name")
rootCmd.PersistentFlags().IntVar(&maxTokens, "set-max-tokens", 0, "Set a new default max token size by specifying the max tokens")
@@ -159,6 +161,20 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}
if listThreads {
cm := configmanager.New(config.New())
threads, err := cm.ListThreads()
if err != nil {
return err
}
fmt.Println("Available threads:")
for _, thread := range threads {
fmt.Println(thread)
}
return nil
}
if interactiveMode {
fmt.Println("Entering interactive mode. Type 'exit' and press Enter or press Ctrl+C to quit.")

View File

@@ -26,6 +26,7 @@ const (
)
type ConfigStore interface {
List() ([]string, error)
Read() (types.Config, error)
ReadDefaults() types.Config
Write(types.Config) error
@@ -35,21 +36,45 @@ type ConfigStore interface {
var _ ConfigStore = &FileIO{}
type FileIO struct {
configFilePath string
configFilePath string
historyFilePath string
}
func New() *FileIO {
path, _ := getPath()
configPath, _ := getPath()
historyPath, _ := utils.GetHistoryDir()
return &FileIO{
configFilePath: path,
configFilePath: configPath,
historyFilePath: historyPath,
}
}
func (f *FileIO) WithFilePath(configFilePath string) *FileIO {
func (f *FileIO) WithConfigPath(configFilePath string) *FileIO {
f.configFilePath = configFilePath
return f
}
func (f *FileIO) WithHistoryPath(historyPath string) *FileIO {
f.historyFilePath = historyPath
return f
}
func (f *FileIO) List() ([]string, error) {
var result []string
files, err := os.ReadDir(f.historyFilePath)
if err != nil {
return nil, err
}
for _, file := range files {
result = append(result, file.Name())
}
return result, nil
}
func (f *FileIO) Read() (types.Config, error) {
return parseFile(f.configFilePath)
}

View File

@@ -1,6 +1,7 @@
package configmanager
import (
"fmt"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/types"
"gopkg.in/yaml.v3"
@@ -35,6 +36,31 @@ func (c *ConfigManager) APIKeyEnvVarName() string {
return strings.ToUpper(c.Config.Name) + "_" + "API_KEY"
}
// ListThreads retrieves a list of all threads stored in the configuration.
// It marks the current thread with an asterisk (*) and returns the list sorted alphabetically.
// If an error occurs while retrieving the threads from the config store, it returns the error.
func (c *ConfigManager) ListThreads() ([]string, error) {
var result []string
threads, err := c.configStore.List()
if err != nil {
return nil, err
}
for _, thread := range threads {
thread = strings.ReplaceAll(thread, ".json", "")
if thread != c.Config.Thread {
result = append(result, fmt.Sprintf("- %s", thread))
continue
}
result = append(result, fmt.Sprintf("* %s (current)", thread))
}
return result, nil
}
// ShowConfig serializes the current configuration to a YAML string.
// It returns the serialized string or an error if the serialization fails.
func (c *ConfigManager) ShowConfig() (string, error) {
data, err := yaml.Marshal(c.Config)
if err != nil {
@@ -44,18 +70,24 @@ func (c *ConfigManager) ShowConfig() (string, error) {
return string(data), nil
}
// WriteMaxTokens updates the maximum number of tokens in the current configuration.
// It writes the updated configuration to the config store and returns an error if the write fails.
func (c *ConfigManager) WriteMaxTokens(tokens int) error {
c.Config.MaxTokens = tokens
return c.configStore.Write(c.Config)
}
// WriteModel updates the model in the current configuration.
// It writes the updated configuration to the config store and returns an error if the write fails.
func (c *ConfigManager) WriteModel(model string) error {
c.Config.Model = model
return c.configStore.Write(c.Config)
}
// WriteThread updates the current thread in the configuration.
// It writes the updated configuration to the config store and returns an error if the write fails.
func (c *ConfigManager) WriteThread(thread string) error {
c.Config.Thread = thread

View File

@@ -262,6 +262,45 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
subject.WriteThread(thread)
})
})
when("ListThreads()", func() {
activeThread := "active-thread"
it.Before(func() {
userConfig := types.Config{Thread: activeThread}
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
})
it("throws an error when the List call fails", func() {
subject := configmanager.New(mockConfigStore).WithEnvironment()
errorInstance := errors.New("an error occurred")
mockConfigStore.EXPECT().List().Return(nil, errorInstance).Times(1)
_, err := subject.ListThreads()
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(errorInstance))
})
it("returns the expected threads", func() {
subject := configmanager.New(mockConfigStore).WithEnvironment()
threads := []string{"thread1.json", "thread2.json", activeThread + ".json"}
mockConfigStore.EXPECT().List().Return(threads, nil).Times(1)
result, err := subject.ListThreads()
Expect(err).NotTo(HaveOccurred())
Expect(result).To(HaveLen(3))
Expect(result[0]).NotTo(ContainSubstring("current"))
Expect(result[0]).NotTo(ContainSubstring("json"))
Expect(result[1]).NotTo(ContainSubstring("current"))
Expect(result[1]).NotTo(ContainSubstring("json"))
Expect(result[2]).To(ContainSubstring("current"))
Expect(result[2]).NotTo(ContainSubstring("json"))
})
})
}
func performWriteTest(mockConfigStore *MockConfigStore, defaultConfig types.Config, expectedValue interface{}, fieldName string, action func()) {

View File

@@ -34,6 +34,21 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// List mocks base method.
func (m *MockConfigStore) List() ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockConfigStoreMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockConfigStore)(nil).List))
}
// Read mocks base method.
func (m *MockConfigStore) Read() (types.Config, error) {
m.ctrl.T.Helper()

View File

@@ -11,8 +11,7 @@ import (
)
const (
historyDirName = "history"
jsonExtension = ".json"
jsonExtension = ".json"
)
type HistoryStore interface {
@@ -33,7 +32,7 @@ type FileIO struct {
func New() (*FileIO, error) {
_ = migrate()
dir, err := getHistoryDir()
dir, err := utils.GetHistoryDir()
if err != nil {
return nil, err
}
@@ -88,15 +87,6 @@ func (f *FileIO) getPath() string {
return filepath.Join(f.historyDir, f.thread+jsonExtension)
}
func getHistoryDir() (string, error) {
homeDir, err := utils.GetChatGPTDirectory()
if err != nil {
return "", err
}
return filepath.Join(homeDir, historyDirName), nil
}
// migrate moves the legacy "history" file in ~/.chatgpt-cli to "history/default.json"
func migrate() error {
hiddenDir, err := utils.GetChatGPTDirectory()
@@ -104,7 +94,10 @@ func migrate() error {
return err
}
historyFile := path.Join(hiddenDir, historyDirName)
historyFile, err := utils.GetHistoryDir()
if err != nil {
return err
}
fileInfo, err := os.Stat(historyFile)
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"os"
"os/exec"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
@@ -102,10 +103,11 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
})
})
when("Read, Write Config", func() {
when("Read, Write, List Config", func() {
var (
tmpDir string
tmpFile *os.File
historyDir string
configIO *config.FileIO
testConfig types.Config
err error
@@ -115,12 +117,15 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
tmpDir, err = os.MkdirTemp("", "chatgpt-cli-test")
Expect(err).NotTo(HaveOccurred())
historyDir, err = os.MkdirTemp(tmpDir, "history")
Expect(err).NotTo(HaveOccurred())
tmpFile, err = os.CreateTemp(tmpDir, "config.yaml")
Expect(err).NotTo(HaveOccurred())
Expect(tmpFile.Close()).To(Succeed())
configIO = config.New().WithFilePath(tmpFile.Name())
configIO = config.New().WithConfigPath(tmpFile.Name()).WithHistoryPath(historyDir)
testConfig = types.Config{
Model: "test-model",
@@ -145,6 +150,22 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(readConfig).To(Equal(testConfig))
})
it("lists all the threads", func() {
files := []string{"thread1.json", "thread2.json", "thread3.json"}
for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())
Expect(file.Close()).To(Succeed())
}
result, err := configIO.List()
Expect(err).NotTo(HaveOccurred())
Expect(result).To(HaveLen(3))
Expect(result[2]).To(Equal("thread3.json"))
})
// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration.
it("overwrites the existing config", func() {
newConfig := types.Config{
@@ -255,6 +276,17 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(output).To(ContainSubstring(".chatgpt-cli: no such file or directory"))
})
it("should require a hidden folder for the --list-threads flag", func() {
command := exec.Command(binaryPath, "--list-threads")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitFailure))
output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(".chatgpt-cli/history: no such file or directory"))
})
it("should require an argument for the --set-model flag", func() {
command := exec.Command(binaryPath, "--set-model")
session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -468,6 +500,29 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(os.Unsetenv(omitHistoryEnvKey)).To(Succeed())
})
it("should return the expected result for the --list-threads flag", func() {
historyDir := path.Join(filePath, "history")
Expect(os.Mkdir(historyDir, 0755)).To(Succeed())
files := []string{"thread1.json", "thread2.json", "thread3.json", "default.json"}
os.Mkdir(historyDir, 7555)
for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())
Expect(file.Close()).To(Succeed())
}
output := runCommand("--list-threads")
Expect(output).To(ContainSubstring("* default (current)"))
Expect(output).To(ContainSubstring("- thread1"))
Expect(output).To(ContainSubstring("- thread2"))
Expect(output).To(ContainSubstring("- thread3"))
})
when("configurable flags are set", func() {
it.Before(func() {
configFile = path.Join(filePath, "config.yaml")

View File

@@ -5,10 +5,24 @@ import (
"path/filepath"
)
const (
cliDirName = ".chatgpt-cli"
historyDirName = "history"
)
func GetChatGPTDirectory() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(homeDir, ".chatgpt-cli"), nil
return filepath.Join(homeDir, cliDirName), nil
}
func GetHistoryDir() (string, error) {
homeDir, err := GetChatGPTDirectory()
if err != nil {
return "", err
}
return filepath.Join(homeDir, historyDirName), nil
}