Add --delete-thread flag for thread management

This commit is contained in:
kardolus
2024-04-13 16:53:49 -04:00
parent 321727f8c0
commit dcdba7b4c3
9 changed files with 151 additions and 59 deletions

View File

@@ -34,6 +34,20 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockConfigStore) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockConfigStoreMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockConfigStore)(nil).Delete), arg0)
}
// List mocks base method.
func (m *MockConfigStore) List() ([]string, error) {
m.ctrl.T.Helper()

View File

@@ -34,20 +34,6 @@ func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockHistoryStore) Delete() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete")
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockHistoryStoreMockRecorder) Delete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockHistoryStore)(nil).Delete))
}
// Read mocks base method.
func (m *MockHistoryStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()

View File

@@ -57,6 +57,7 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&listThreads, "list-threads", "", false, "List available threads")
rootCmd.PersistentFlags().StringVar(&modelName, "set-model", "", "Set a new default GPT model by specifying the model name")
rootCmd.PersistentFlags().StringVar(&threadName, "set-thread", "", "Set a new active thread by specifying the thread name")
rootCmd.PersistentFlags().StringVar(&threadName, "delete-thread", "", "Delete the specified thread")
rootCmd.PersistentFlags().StringVar(&shell, "set-completions", "", "Generate autocompletion script for your current shell")
rootCmd.PersistentFlags().IntVar(&maxTokens, "set-max-tokens", 0, "Set a new default max token size by specifying the max tokens")
rootCmd.PersistentFlags().IntVar(&contextWindow, "set-context-window", 0, "Set a new default context window size")
@@ -120,6 +121,16 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}
if cmd.Flag("delete-thread").Changed {
cm := configmanager.New(config.New())
if err := cm.DeleteThread(threadName); err != nil {
return err
}
fmt.Printf("Successfully deleted thead %s\n", threadName)
return nil
}
if listThreads {
cm := configmanager.New(config.New())
@@ -135,15 +146,9 @@ func run(cmd *cobra.Command, args []string) error {
}
if clearHistory {
historyHandler, err := history.New()
if err != nil {
return err
}
cm := configmanager.New(config.New())
historyHandler.SetThread(cm.Config.Thread)
if err := historyHandler.Delete(); err != nil {
if err := cm.DeleteThread(cm.Config.Thread); err != nil {
return err
}

View File

@@ -29,6 +29,7 @@ const (
)
type ConfigStore interface {
Delete(string) error
List() ([]string, error)
Read() (types.Config, error)
ReadDefaults() types.Config
@@ -63,6 +64,15 @@ func (f *FileIO) WithHistoryPath(historyPath string) *FileIO {
return f
}
func (f *FileIO) Delete(name string) error {
path := filepath.Join(f.historyFilePath, name+".json")
if _, err := os.Stat(path); err == nil {
return os.Remove(path)
}
return nil
}
func (f *FileIO) List() ([]string, error) {
var result []string

View File

@@ -36,6 +36,12 @@ func (c *ConfigManager) APIKeyEnvVarName() string {
return strings.ToUpper(c.Config.Name) + "_" + "API_KEY"
}
// DeleteThread removes the specified thread from the configuration store.
// This operation is idempotent; non-existent threads do not cause errors.
func (c *ConfigManager) DeleteThread(thread string) error {
return c.configStore.Delete(thread)
}
// ListThreads retrieves a list of all threads stored in the configuration.
// It marks the current thread with an asterisk (*) and returns the list sorted alphabetically.
// If an error occurs while retrieving the threads from the config store, it returns the error.

View File

@@ -126,8 +126,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: 5.5,
}
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(userConfig, nil).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
@@ -166,8 +166,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
os.Setenv(envPrefix+"FREQUENCY_PENALTY", "4.4")
os.Setenv(envPrefix+"PRESENCE_PENALTY", "5.5")
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("config error")).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{}, errors.New("config error")).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
@@ -226,8 +226,8 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: 4.5,
}
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().Read().Return(userConfig, nil).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
subject := configmanager.New(mockConfigStore).WithEnvironment()
@@ -281,13 +281,46 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
})
})
when("DeleteThread()", func() {
var subject *configmanager.ConfigManager
threadName := "non-active-thread"
it.Before(func() {
userConfig := types.Config{Thread: threadName}
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
subject = configmanager.New(mockConfigStore).WithEnvironment()
})
it("propagates the error from the config store", func() {
expectedMsg := "expected-error"
mockConfigStore.EXPECT().Delete(threadName).Return(errors.New(expectedMsg)).Times(1)
err := subject.DeleteThread(threadName)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(expectedMsg))
})
it("completes successfully the config store throws no error", func() {
mockConfigStore.EXPECT().Delete(threadName).Return(nil).Times(1)
err := subject.DeleteThread(threadName)
Expect(err).NotTo(HaveOccurred())
})
})
when("ListThreads()", func() {
activeThread := "active-thread"
it.Before(func() {
userConfig := types.Config{Thread: activeThread}
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).AnyTimes()
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(userConfig, nil).Times(1)
})

View File

@@ -34,6 +34,20 @@ func (m *MockConfigStore) EXPECT() *MockConfigStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockConfigStore) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockConfigStoreMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockConfigStore)(nil).Delete), arg0)
}
// List mocks base method.
func (m *MockConfigStore) List() ([]string, error) {
m.ctrl.T.Helper()

View File

@@ -15,7 +15,6 @@ const (
)
type HistoryStore interface {
Delete() error
Read() ([]types.Message, error)
Write([]types.Message) error
SetThread(thread string)
@@ -63,13 +62,6 @@ func (f *FileIO) WithDirectory(historyDir string) *FileIO {
return f
}
func (f *FileIO) Delete() error {
if _, err := os.Stat(f.getPath()); err == nil {
return os.Remove(f.getPath())
}
return nil
}
func (f *FileIO) Read() ([]types.Message, error) {
return parseFile(f.getPath())
}

View File

@@ -46,7 +46,7 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
})
when("Read, Write and Delete History", func() {
when("Read and Write History", func() {
const threadName = "default-thread"
var (
@@ -93,17 +93,9 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
Expect(readMessages).To(Equal(messages))
})
it("deletes the file", func() {
err = fileIO.Delete()
Expect(err).NotTo(HaveOccurred())
_, err = os.Stat(threadName + ".json")
Expect(os.IsNotExist(err)).To(BeTrue())
})
})
when("Read, Write, List Config", func() {
when("Read, Write, List, Delete Config", func() {
var (
tmpDir string
tmpFile *os.File
@@ -236,6 +228,26 @@ max_tokens: 100
Expect(result[2]).To(Equal("thread3.json"))
})
it("deletes the thread", func() {
files := []string{"thread1.json", "thread2.json", "thread3.json"}
for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())
Expect(file.Close()).To(Succeed())
}
err = configIO.Delete("thread2")
Expect(err).NotTo(HaveOccurred())
_, err = os.Stat(filepath.Join(historyDir, "thread2.json"))
Expect(os.IsNotExist(err)).To(BeTrue())
_, err = os.Stat(filepath.Join(historyDir, "thread3.json"))
Expect(os.IsNotExist(err)).To(BeFalse())
})
// Since we don't have a Delete method in the config, we will test if we can overwrite the configuration.
it("overwrites the existing config", func() {
newConfig := types.Config{
@@ -333,19 +345,6 @@ max_tokens: 100
Eventually(session).Should(gexec.Exit(exitSuccess))
})
it("should require a hidden folder for the --clear-history flag", func() {
Expect(os.Unsetenv(apiKeyEnvVar)).To(Succeed())
command := exec.Command(binaryPath, "--clear-history")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitFailure))
output := string(session.Out.Contents())
Expect(output).To(ContainSubstring(".chatgpt-cli: no such file or directory"))
})
it("should require a hidden folder for the --list-threads flag", func() {
command := exec.Command(binaryPath, "--list-threads")
session, err := gexec.Start(command, io.Discard, io.Discard)
@@ -631,6 +630,39 @@ max_tokens: 100
Expect(output).To(ContainSubstring("- thread3"))
})
it("should delete the expected thread using the --delete-threads flag", func() {
historyDir := path.Join(filePath, "history")
Expect(os.Mkdir(historyDir, 0755)).To(Succeed())
files := []string{"thread1.json", "thread2.json", "thread3.json", "default.json"}
os.Mkdir(historyDir, 7555)
for _, file := range files {
file, err := os.Create(filepath.Join(historyDir, file))
Expect(err).NotTo(HaveOccurred())
Expect(file.Close()).To(Succeed())
}
runCommand("--delete-thread", "thread2")
output := runCommand("--list-threads")
Expect(output).To(ContainSubstring("* default (current)"))
Expect(output).To(ContainSubstring("- thread1"))
Expect(output).NotTo(ContainSubstring("- thread2"))
Expect(output).To(ContainSubstring("- thread3"))
})
it("should not throw an error when a non-existent thread is deleted using the --delete-threads flag", func() {
command := exec.Command(binaryPath, "--delete-thread", "does-not-exist")
session, err := gexec.Start(command, io.Discard, io.Discard)
Expect(err).NotTo(HaveOccurred())
Eventually(session).Should(gexec.Exit(exitSuccess))
})
when("configurable flags are set", func() {
it.Before(func() {
configFile = path.Join(filePath, "config.yaml")