mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add auto-create new thread feature in interactive mode
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/kardolus/chatgpt-cli/config"
|
||||
"github.com/kardolus/chatgpt-cli/configmanager"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
@@ -20,6 +21,7 @@ const (
|
||||
MaxTokenBufferPercentage = 20
|
||||
SystemRole = "system"
|
||||
UserRole = "user"
|
||||
InteractiveThreadPrefix = "int_"
|
||||
gptPrefix = "gpt"
|
||||
)
|
||||
|
||||
@@ -30,7 +32,7 @@ type Client struct {
|
||||
historyStore history.HistoryStore
|
||||
}
|
||||
|
||||
func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
|
||||
func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.HistoryStore, interactiveMode bool) (*Client, error) {
|
||||
cm := configmanager.New(cs).WithEnvironment()
|
||||
configuration := cm.Config
|
||||
|
||||
@@ -40,7 +42,11 @@ func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.His
|
||||
|
||||
caller := callerFactory(configuration)
|
||||
|
||||
hs.SetThread(configuration.Thread)
|
||||
if interactiveMode && cm.Config.AutoCreateNewThread {
|
||||
hs.SetThread(generateUniqueSlug())
|
||||
} else {
|
||||
hs.SetThread(configuration.Thread)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
Config: configuration,
|
||||
@@ -306,3 +312,8 @@ func createMessagesFromString(input string) []types.Message {
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func generateUniqueSlug() string {
|
||||
guid := uuid.New()
|
||||
return InteractiveThreadPrefix + guid.String()[:4]
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
defaultTopP = 2.2
|
||||
defaultFrequencyPenalty = 3.3
|
||||
defaultPresencePenalty = 4.4
|
||||
defaultInteractiveMode = false
|
||||
envApiKey = "api-key"
|
||||
)
|
||||
|
||||
@@ -78,11 +79,66 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
|
||||
|
||||
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
|
||||
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, defaultInteractiveMode)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring(apiKeyEnvVar))
|
||||
})
|
||||
it("should set a unique thread slug in interactive mode when AutoCreateNewThread is true", func() {
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{
|
||||
AutoCreateNewThread: true,
|
||||
Thread: defaultThread,
|
||||
}, nil).Times(1)
|
||||
|
||||
var capturedThread string
|
||||
mockHistoryStore.EXPECT().SetThread(gomock.Any()).DoAndReturn(func(thread string) {
|
||||
capturedThread = thread
|
||||
}).Times(1)
|
||||
|
||||
interactiveMode := true
|
||||
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
|
||||
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(capturedThread).To(HavePrefix(client.InteractiveThreadPrefix)) // Assuming `InteractiveThreadPrefix` is "int_"
|
||||
Expect(len(capturedThread)).To(Equal(8)) // "int_" (4 chars) + 4 random characters
|
||||
})
|
||||
it("should not overwrite the thread in interactive mode when AutoCreateNewThread is false", func() {
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{
|
||||
AutoCreateNewThread: false,
|
||||
Thread: defaultThread,
|
||||
}, nil).Times(1)
|
||||
|
||||
var capturedThread string
|
||||
mockHistoryStore.EXPECT().SetThread(defaultThread).DoAndReturn(func(thread string) {
|
||||
capturedThread = thread
|
||||
}).Times(1)
|
||||
|
||||
interactiveMode := true
|
||||
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
|
||||
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(capturedThread).To(Equal(defaultThread))
|
||||
})
|
||||
it("should never overwrite the thread in non-interactive mode", func() {
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{
|
||||
AutoCreateNewThread: true,
|
||||
Thread: defaultThread,
|
||||
}, nil).Times(1)
|
||||
|
||||
var capturedThread string
|
||||
mockHistoryStore.EXPECT().SetThread(defaultThread).DoAndReturn(func(thread string) {
|
||||
capturedThread = thread
|
||||
}).Times(1)
|
||||
|
||||
interactiveMode := false
|
||||
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, interactiveMode)
|
||||
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(capturedThread).To(Equal(defaultThread))
|
||||
})
|
||||
})
|
||||
|
||||
when("Query()", func() {
|
||||
@@ -260,7 +316,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
|
||||
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: true}, nil).Times(1)
|
||||
|
||||
subject, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
|
||||
subject, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore, defaultInteractiveMode)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Read and Write are never called on the history store
|
||||
@@ -549,7 +605,7 @@ func (f *clientFactory) buildClientWithoutConfig() *client.Client {
|
||||
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
|
||||
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
|
||||
|
||||
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
|
||||
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore, defaultInteractiveMode)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
return c.WithContextWindow(defaultContextWindow)
|
||||
@@ -559,7 +615,7 @@ func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Clien
|
||||
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
|
||||
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
|
||||
|
||||
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
|
||||
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore, defaultInteractiveMode)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
return c.WithContextWindow(defaultContextWindow)
|
||||
|
||||
@@ -34,6 +34,20 @@ func (m *MockHistoryStore) EXPECT() *MockHistoryStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetThread mocks base method.
|
||||
func (m *MockHistoryStore) GetThread() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetThread")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetThread indicates an expected call of GetThread.
|
||||
func (mr *MockHistoryStoreMockRecorder) GetThread() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThread", reflect.TypeOf((*MockHistoryStore)(nil).GetThread))
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockHistoryStore) Read() ([]types.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user