Add auto-create new thread feature in interactive mode

This commit is contained in:
kardolus
2024-08-16 10:11:11 -04:00
parent 38a41279bf
commit e9756499a4
32 changed files with 1676 additions and 117 deletions

View File

@@ -7,6 +7,7 @@ import (
"strings"
"unicode/utf8"
"github.com/google/uuid"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/configmanager"
"github.com/kardolus/chatgpt-cli/history"
@@ -20,6 +21,7 @@ const (
MaxTokenBufferPercentage = 20
SystemRole = "system"
UserRole = "user"
InteractiveThreadPrefix = "int_"
gptPrefix = "gpt"
)
@@ -30,7 +32,7 @@ type Client struct {
historyStore history.HistoryStore
}
func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.HistoryStore, interactiveMode bool) (*Client, error) {
cm := configmanager.New(cs).WithEnvironment()
configuration := cm.Config
@@ -40,7 +42,11 @@ func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.His
caller := callerFactory(configuration)
hs.SetThread(configuration.Thread)
if interactiveMode && cm.Config.AutoCreateNewThread {
hs.SetThread(generateUniqueSlug())
} else {
hs.SetThread(configuration.Thread)
}
return &Client{
Config: configuration,
@@ -306,3 +312,8 @@ func createMessagesFromString(input string) []types.Message {
return messages
}
func generateUniqueSlug() string {
guid := uuid.New()
return InteractiveThreadPrefix + guid.String()[:4]
}

View File

@@ -36,6 +36,7 @@ const (
defaultTopP = 2.2
defaultFrequencyPenalty = 3.3
defaultPresencePenalty = 4.4
defaultInteractiveMode = false
envApiKey = "api-key"
)
@@ -78,11 +79,66 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, defaultInteractiveMode)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(apiKeyEnvVar))
})
it("should set a unique thread slug in interactive mode when AutoCreateNewThread is true", func() {
mockConfigStore.EXPECT().Read().Return(types.Config{
AutoCreateNewThread: true,
Thread: defaultThread,
}, nil).Times(1)
var capturedThread string
mockHistoryStore.EXPECT().SetThread(gomock.Any()).DoAndReturn(func(thread string) {
capturedThread = thread
}).Times(1)
interactiveMode := true
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
Expect(err).NotTo(HaveOccurred())
Expect(capturedThread).To(HavePrefix(client.InteractiveThreadPrefix)) // Assuming `InteractiveThreadPrefix` is "int_"
Expect(len(capturedThread)).To(Equal(8)) // "int_" (4 chars) + 4 random characters
})
it("should not overwrite the thread in interactive mode when AutoCreateNewThread is false", func() {
mockConfigStore.EXPECT().Read().Return(types.Config{
AutoCreateNewThread: false,
Thread: defaultThread,
}, nil).Times(1)
var capturedThread string
mockHistoryStore.EXPECT().SetThread(defaultThread).DoAndReturn(func(thread string) {
capturedThread = thread
}).Times(1)
interactiveMode := true
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
Expect(err).NotTo(HaveOccurred())
Expect(capturedThread).To(Equal(defaultThread))
})
it("should never overwrite the thread in non-interactive mode", func() {
mockConfigStore.EXPECT().Read().Return(types.Config{
AutoCreateNewThread: true,
Thread: defaultThread,
}, nil).Times(1)
var capturedThread string
mockHistoryStore.EXPECT().SetThread(defaultThread).DoAndReturn(func(thread string) {
capturedThread = thread
}).Times(1)
interactiveMode := false
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
Expect(err).NotTo(HaveOccurred())
Expect(capturedThread).To(Equal(defaultThread))
})
})
when("Query()", func() {
@@ -260,7 +316,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: true}, nil).Times(1)
subject, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
subject, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, defaultInteractiveMode)
Expect(err).NotTo(HaveOccurred())
// Read and Write are never called on the history store
@@ -549,7 +605,7 @@ func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore, defaultInteractiveMode)
Expect(err).NotTo(HaveOccurred())
return c.WithContextWindow(defaultContextWindow)
@@ -559,7 +615,7 @@ func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Clien
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore, defaultInteractiveMode)
Expect(err).NotTo(HaveOccurred())
return c.WithContextWindow(defaultContextWindow)

View File

@@ -34,6 +34,20 @@ func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
return m.recorder
}
// GetThread mocks base method.
func (m *MockHistoryStore) GetThread() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetThread")
ret0, _ := ret[0].(string)
return ret0
}
// GetThread indicates an expected call of GetThread.
func (mr *MockHistoryStoreMockRecorder) GetThread() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThread", reflect.TypeOf((*MockHistoryStore)(nil).GetThread))
}
// Read mocks base method.
func (m *MockHistoryStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()