mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add history handling
This commit is contained in:
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Guillermo Kardolus
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
17
README.md
17
README.md
@@ -9,6 +9,8 @@ environment, demonstrating its practicality and effectiveness.
|
||||
|
||||
* Interactive streaming mode for real-time interaction with the GPT model.
|
||||
* Query mode for single input-output interactions with the GPT model.
|
||||
* Context management across CLI calls, enabling seamless conversations with the GPT model by maintaining message
|
||||
history.
|
||||
* Viper integration for robust configuration management.
|
||||
|
||||
## Development
|
||||
@@ -35,17 +37,22 @@ building the application:
|
||||
./bin/chatgpt what type of dog is a Jack Russel?
|
||||
```
|
||||
|
||||
4. To enable history tracking across CLI calls, create a ~/.chatgpt-cli directory using the command:
|
||||
|
||||
```shell
|
||||
mkdir ~/.chatgpt-cli
|
||||
```
|
||||
|
||||
With this directory in place, the CLI will automatically manage message history for seamless conversations with the GPT
|
||||
model. The history acts as a sliding window, maintaining a maximum of 4096 tokens to ensure optimal performance and
|
||||
interaction quality.
|
||||
|
||||
For more options, see:
|
||||
|
||||
```shell
|
||||
./bin/chatgpt --help
|
||||
```
|
||||
|
||||
## Up Next
|
||||
|
||||
* Maintain context across multiple calls to ChatGPT.
|
||||
* Reset the context with a CLI command.
|
||||
|
||||
## Useful Links
|
||||
|
||||
* [ChatGPT API Documentation](https://platform.openai.com/docs/introduction/overview)
|
||||
|
||||
@@ -4,22 +4,31 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/http"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
const (
|
||||
model = "gpt-3.5-turbo"
|
||||
role = "user"
|
||||
URL = "https://api.openai.com/v1/chat/completions"
|
||||
AssistantContent = "You are a helpful assistant."
|
||||
AssistantRole = "assistant"
|
||||
GPTModel = "gpt-3.5-turbo"
|
||||
SystemRole = "system"
|
||||
URL = "https://api.openai.com/v1/chat/completions"
|
||||
UserRole = "user"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
caller http.Caller
|
||||
caller http.Caller
|
||||
readWriter history.Store
|
||||
history []types.Message
|
||||
}
|
||||
|
||||
func New(caller http.Caller) *Client {
|
||||
return &Client{caller: caller}
|
||||
func New(caller http.Caller, rw history.Store) *Client {
|
||||
return &Client{
|
||||
caller: caller,
|
||||
readWriter: rw,
|
||||
}
|
||||
}
|
||||
|
||||
// Query sends a query to the API and returns the response as a string.
|
||||
@@ -29,7 +38,9 @@ func New(caller http.Caller) *Client {
|
||||
// call using the Post method. If the response is not empty, it decodes the
|
||||
// response JSON and returns the content of the first choice.
|
||||
func (c *Client) Query(input string) (string, error) {
|
||||
body, err := CreateBody(input, false)
|
||||
c.initHistory(input)
|
||||
|
||||
body, err := c.createBody(false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -52,6 +63,8 @@ func (c *Client) Query(input string) (string, error) {
|
||||
return "", errors.New("no responses returned")
|
||||
}
|
||||
|
||||
c.updateHistory(response.Choices[0].Message.Content)
|
||||
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -61,35 +74,59 @@ func (c *Client) Query(input string) (string, error) {
|
||||
// input and then makes an API call using the Post method. The actual
|
||||
// processing of the streamed response is done in the Post method.
|
||||
func (c *Client) Stream(input string) error {
|
||||
body, err := CreateBody(input, true)
|
||||
c.initHistory(input)
|
||||
|
||||
body, err := c.createBody(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.caller.Post(URL, body, true)
|
||||
result, err := c.caller.Post(URL, body, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.updateHistory(string(result))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateBody(query string, stream bool) ([]byte, error) {
|
||||
message := types.Message{
|
||||
Role: role,
|
||||
Content: query,
|
||||
}
|
||||
|
||||
func (c *Client) createBody(stream bool) ([]byte, error) {
|
||||
body := types.Request{
|
||||
Model: model,
|
||||
Messages: []types.Message{message},
|
||||
Model: GPTModel,
|
||||
Messages: c.history,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return json.Marshal(body)
|
||||
}
|
||||
|
||||
func (c *Client) initHistory(query string) {
|
||||
message := types.Message{
|
||||
Role: UserRole,
|
||||
Content: query,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
c.history, _ = c.readWriter.Read()
|
||||
if len(c.history) == 0 {
|
||||
c.history = []types.Message{{
|
||||
Role: SystemRole,
|
||||
Content: AssistantContent,
|
||||
}}
|
||||
}
|
||||
|
||||
// TODO Write history specific tests
|
||||
// TODO Write delete-specific tests (on store)
|
||||
// TODO Test the string returned from Stream
|
||||
// TODO implement sliding window
|
||||
|
||||
c.history = append(c.history, message)
|
||||
}
|
||||
|
||||
func (c *Client) updateHistory(response string) {
|
||||
c.history = append(c.history, types.Message{
|
||||
Role: AssistantRole,
|
||||
Content: response,
|
||||
})
|
||||
_ = c.readWriter.Write(c.history)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,13 @@ import (
|
||||
"github.com/sclevine/spec/report"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
|
||||
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
|
||||
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
|
||||
|
||||
var (
|
||||
mockCtrl *gomock.Controller
|
||||
mockCaller *MockCaller
|
||||
mockStore *MockStore
|
||||
subject *client.Client
|
||||
)
|
||||
|
||||
@@ -31,8 +33,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
RegisterTestingT(t)
|
||||
mockCtrl = gomock.NewController(t)
|
||||
mockCaller = NewMockCaller(mockCtrl)
|
||||
mockStore = NewMockStore(mockCtrl)
|
||||
|
||||
subject = client.New(mockCaller)
|
||||
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
|
||||
subject = client.New(mockCaller, mockStore)
|
||||
})
|
||||
|
||||
it.After(func() {
|
||||
@@ -43,12 +47,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
const query = "test query"
|
||||
|
||||
var (
|
||||
err error
|
||||
body []byte
|
||||
body []byte
|
||||
messages []types.Message
|
||||
err error
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
body, err = client.CreateBody(query, false)
|
||||
messages = createMessages(nil, query)
|
||||
body, err = createBody(messages)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -56,14 +62,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(errorMsg))
|
||||
})
|
||||
it("throws an error when the response is empty", func() {
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("empty response"))
|
||||
})
|
||||
@@ -71,7 +77,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
_, err := subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
|
||||
})
|
||||
@@ -93,10 +99,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err.Error()).To(Equal("no responses returned"))
|
||||
})
|
||||
it("parses a valid http response as expected", func() {
|
||||
const answer = "content"
|
||||
|
||||
choice := types.Choice{
|
||||
Message: types.Message{
|
||||
Role: "role",
|
||||
Content: "content",
|
||||
Content: answer,
|
||||
},
|
||||
FinishReason: "",
|
||||
Index: 0,
|
||||
@@ -113,9 +121,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
|
||||
|
||||
mockStore.EXPECT().Write(append(messages, types.Message{
|
||||
Role: client.AssistantRole,
|
||||
Content: answer,
|
||||
}))
|
||||
|
||||
result, err := subject.Query(query)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(result).To(Equal("content"))
|
||||
Expect(result).To(Equal(answer))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func createBody(messages []types.Message) ([]byte, error) {
|
||||
req := types.Request{
|
||||
Model: client.GPTModel,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
func createMessages(history []types.Message, query string) []types.Message {
|
||||
var messages []types.Message
|
||||
|
||||
if len(history) == 0 {
|
||||
messages = append(messages, types.Message{
|
||||
Role: client.SystemRole,
|
||||
Content: client.AssistantContent,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, types.Message{
|
||||
Role: client.UserRole,
|
||||
Content: query,
|
||||
})
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
78
client/iomocks_test.go
Normal file
78
client/iomocks_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
|
||||
|
||||
// Package client_test is a generated GoMock package.
|
||||
package client_test
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
types "github.com/kardolus/chatgpt-cli/types"
|
||||
)
|
||||
|
||||
// MockStore is a mock of Store interface.
|
||||
type MockStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStoreMockRecorder
|
||||
}
|
||||
|
||||
// MockStoreMockRecorder is the mock recorder for MockStore.
|
||||
type MockStoreMockRecorder struct {
|
||||
mock *MockStore
|
||||
}
|
||||
|
||||
// NewMockStore creates a new mock instance.
|
||||
func NewMockStore(ctrl *gomock.Controller) *MockStore {
|
||||
mock := &MockStore{ctrl: ctrl}
|
||||
mock.recorder = &MockStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockStore) Delete() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockStore) Read() ([]types.Message, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read")
|
||||
ret0, _ := ret[0].([]types.Message)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
|
||||
}
|
||||
|
||||
// Write mocks base method.
|
||||
func (m *MockStore) Write(arg0 []types.Message) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write.
|
||||
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-cli/client"
|
||||
"github.com/kardolus/chatgpt-cli/history"
|
||||
"github.com/kardolus/chatgpt-cli/http"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
const secretEnv = "OPENAI_API_KEY"
|
||||
|
||||
var queryMode bool
|
||||
var clearHistory bool
|
||||
|
||||
func main() {
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -24,6 +26,7 @@ func main() {
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().BoolVarP(&queryMode, "query", "q", false, "Use query mode instead of stream mode")
|
||||
rootCmd.PersistentFlags().BoolVarP(&clearHistory, "clear-history", "c", false, "Clear the history of ChatGPT CLI")
|
||||
|
||||
viper.AutomaticEnv()
|
||||
|
||||
@@ -34,15 +37,28 @@ func main() {
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, args []string) error {
|
||||
if clearHistory {
|
||||
historyHandler := history.New()
|
||||
err := historyHandler.Delete()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("History successfully cleared.")
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
return errors.New("you must specify your query")
|
||||
if clearHistory {
|
||||
return nil
|
||||
} else {
|
||||
return errors.New("you must specify your query")
|
||||
}
|
||||
}
|
||||
|
||||
secret := viper.GetString(secretEnv)
|
||||
if secret == "" {
|
||||
return errors.New("missing environment variable: " + secretEnv)
|
||||
}
|
||||
client := client.New(http.New().WithSecret(secret))
|
||||
client := client.New(http.New().WithSecret(secret), history.New())
|
||||
|
||||
if queryMode {
|
||||
result, err := client.Query(strings.Join(args, " "))
|
||||
@@ -55,6 +71,5 @@ func run(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
85
history/store.go
Normal file
85
history/store.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package history
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/kardolus/chatgpt-cli/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
Delete() error
|
||||
Read() ([]types.Message, error)
|
||||
Write([]types.Message) error
|
||||
}
|
||||
|
||||
// Ensure RestCaller implements Caller interface
|
||||
var _ Store = &FileIO{}
|
||||
|
||||
type FileIO struct {
|
||||
}
|
||||
|
||||
func New() *FileIO {
|
||||
return &FileIO{}
|
||||
}
|
||||
|
||||
func (f *FileIO) Delete() error {
|
||||
historyFilePath, err := getPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(historyFilePath); err == nil {
|
||||
return os.Remove(historyFilePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileIO) Read() ([]types.Message, error) {
|
||||
historyFilePath, err := getPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parseFile(historyFilePath)
|
||||
}
|
||||
|
||||
func (f *FileIO) Write(messages []types.Message) error {
|
||||
historyFilePath, err := getPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(historyFilePath, data, 0644)
|
||||
}
|
||||
|
||||
func getPath() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(homeDir, ".chatgpt-cli", "history"), nil
|
||||
}
|
||||
|
||||
func parseFile(fileName string) ([]types.Message, error) {
|
||||
var result []types.Message
|
||||
|
||||
buf, err := ioutil.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(buf, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
10
http/http.go
10
http/http.go
@@ -55,8 +55,7 @@ func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error)
|
||||
|
||||
if response.StatusCode >= 200 && response.StatusCode < 300 {
|
||||
if stream {
|
||||
ProcessResponse(response.Body, os.Stdout)
|
||||
return nil, nil
|
||||
return ProcessResponse(response.Body, os.Stdout), nil
|
||||
} else {
|
||||
result, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
@@ -69,7 +68,9 @@ func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error)
|
||||
return nil, fmt.Errorf("http error: %d", response.StatusCode)
|
||||
}
|
||||
|
||||
func ProcessResponse(r io.Reader, w io.Writer) {
|
||||
func ProcessResponse(r io.Reader, w io.Writer) []byte {
|
||||
var result []byte
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -81,6 +82,7 @@ func ProcessResponse(r io.Reader, w io.Writer) {
|
||||
|
||||
if line == "[DONE]" {
|
||||
_, _ = w.Write([]byte("\n"))
|
||||
result = append(result, []byte("\n")...)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -94,8 +96,10 @@ func ProcessResponse(r io.Reader, w io.Writer) {
|
||||
for _, choice := range data.Choices {
|
||||
if content, ok := choice.Delta["content"]; ok {
|
||||
_, _ = w.Write([]byte(content))
|
||||
result = append(result, []byte(content)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user