Add history handling

This commit is contained in:
kardolus
2023-05-03 01:03:09 -04:00
parent fd892219da
commit 5be0af6cb0
9 changed files with 330 additions and 42 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Guillermo Kardolus
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -9,6 +9,8 @@ environment, demonstrating its practicality and effectiveness.
* Interactive streaming mode for real-time interaction with the GPT model.
* Query mode for single input-output interactions with the GPT model.
* Context management across CLI calls, enabling seamless conversations with the GPT model by maintaining message
history.
* Viper integration for robust configuration management.
## Development
@@ -35,17 +37,22 @@ building the application:
./bin/chatgpt what type of dog is a Jack Russel?
```
4. To enable history tracking across CLI calls, create a ~/.chatgpt-cli directory using the command:
```shell
mkdir ~/.chatgpt-cli
```
With this directory in place, the CLI will automatically manage message history for seamless conversations with the GPT
model. The history acts as a sliding window, maintaining a maximum of 4096 tokens to ensure optimal performance and
interaction quality.
For more options, see:
```shell
./bin/chatgpt --help
```
## Up Next
* Maintain context across multiple calls to ChatGPT.
* Reset the context with a CLI command.
## Useful Links
* [ChatGPT API Documentation](https://platform.openai.com/docs/introduction/overview)

View File

@@ -4,22 +4,31 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
)
const (
model = "gpt-3.5-turbo"
role = "user"
URL = "https://api.openai.com/v1/chat/completions"
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
GPTModel = "gpt-3.5-turbo"
SystemRole = "system"
URL = "https://api.openai.com/v1/chat/completions"
UserRole = "user"
)
type Client struct {
caller http.Caller
caller http.Caller
readWriter history.Store
history []types.Message
}
func New(caller http.Caller) *Client {
return &Client{caller: caller}
func New(caller http.Caller, rw history.Store) *Client {
return &Client{
caller: caller,
readWriter: rw,
}
}
// Query sends a query to the API and returns the response as a string.
@@ -29,7 +38,9 @@ func New(caller http.Caller) *Client {
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
body, err := CreateBody(input, false)
c.initHistory(input)
body, err := c.createBody(false)
if err != nil {
return "", err
}
@@ -52,6 +63,8 @@ func (c *Client) Query(input string) (string, error) {
return "", errors.New("no responses returned")
}
c.updateHistory(response.Choices[0].Message.Content)
return response.Choices[0].Message.Content, nil
}
@@ -61,35 +74,59 @@ func (c *Client) Query(input string) (string, error) {
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
body, err := CreateBody(input, true)
c.initHistory(input)
body, err := c.createBody(true)
if err != nil {
return err
}
_, err = c.caller.Post(URL, body, true)
result, err := c.caller.Post(URL, body, true)
if err != nil {
return err
}
c.updateHistory(string(result))
return nil
}
func CreateBody(query string, stream bool) ([]byte, error) {
message := types.Message{
Role: role,
Content: query,
}
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.Request{
Model: model,
Messages: []types.Message{message},
Model: GPTModel,
Messages: c.history,
Stream: stream,
}
result, err := json.Marshal(body)
if err != nil {
return nil, err
return json.Marshal(body)
}
func (c *Client) initHistory(query string) {
message := types.Message{
Role: UserRole,
Content: query,
}
return result, nil
c.history, _ = c.readWriter.Read()
if len(c.history) == 0 {
c.history = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
// TODO Write history specific tests
// TODO Write delete-specific tests (on store)
// TODO Test the string returned from Stream
// TODO implement sliding window
c.history = append(c.history, message)
}
func (c *Client) updateHistory(response string) {
c.history = append(c.history, types.Message{
Role: AssistantRole,
Content: response,
})
_ = c.readWriter.Write(c.history)
}

View File

@@ -14,11 +14,13 @@ import (
"github.com/sclevine/spec/report"
)
//go:generate mockgen -destination=mocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/http Caller
//go:generate mockgen -destination=iomocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
var (
mockCtrl *gomock.Controller
mockCaller *MockCaller
mockStore *MockStore
subject *client.Client
)
@@ -31,8 +33,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
mockCaller = NewMockCaller(mockCtrl)
mockStore = NewMockStore(mockCtrl)
subject = client.New(mockCaller)
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
subject = client.New(mockCaller, mockStore)
})
it.After(func() {
@@ -43,12 +47,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
const query = "test query"
var (
err error
body []byte
body []byte
messages []types.Message
err error
)
it.Before(func() {
body, err = client.CreateBody(query, false)
messages = createMessages(nil, query)
body, err = createBody(messages)
Expect(err).NotTo(HaveOccurred())
})
@@ -56,14 +62,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
@@ -71,7 +77,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
_, err = subject.Query(query)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
})
@@ -93,10 +99,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err.Error()).To(Equal("no responses returned"))
})
it("parses a valid http response as expected", func() {
const answer = "content"
choice := types.Choice{
Message: types.Message{
Role: "role",
Content: "content",
Content: answer,
},
FinishReason: "",
Index: 0,
@@ -113,9 +121,42 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
mockStore.EXPECT().Write(append(messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
result, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("content"))
Expect(result).To(Equal(answer))
})
})
}
func createBody(messages []types.Message) ([]byte, error) {
req := types.Request{
Model: client.GPTModel,
Messages: messages,
Stream: false,
}
return json.Marshal(req)
}
func createMessages(history []types.Message, query string) []types.Message {
var messages []types.Message
if len(history) == 0 {
messages = append(messages, types.Message{
Role: client.SystemRole,
Content: client.AssistantContent,
})
}
messages = append(messages, types.Message{
Role: client.UserRole,
Content: query,
})
return messages
}

78
client/iomocks_test.go Normal file
View File

@@ -0,0 +1,78 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/kardolus/chatgpt-cli/history (interfaces: Store)
// Package client_test is a generated GoMock package.
package client_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
types "github.com/kardolus/chatgpt-cli/types"
)
// MockStore is a mock of Store interface.
type MockStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockStore) Delete() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete")
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockStoreMockRecorder) Delete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete))
}
// Read mocks base method.
func (m *MockStore) Read() ([]types.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read")
ret0, _ := ret[0].([]types.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockStoreMockRecorder) Read() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read))
}
// Write mocks base method.
func (m *MockStore) Write(arg0 []types.Message) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Write indicates an expected call of Write.
func (mr *MockStoreMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStore)(nil).Write), arg0)
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/http"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -14,6 +15,7 @@ import (
const secretEnv = "OPENAI_API_KEY"
var queryMode bool
var clearHistory bool
func main() {
var rootCmd = &cobra.Command{
@@ -24,6 +26,7 @@ func main() {
}
rootCmd.PersistentFlags().BoolVarP(&queryMode, "query", "q", false, "Use query mode instead of stream mode")
rootCmd.PersistentFlags().BoolVarP(&clearHistory, "clear-history", "c", false, "Clear the history of ChatGPT CLI")
viper.AutomaticEnv()
@@ -34,15 +37,28 @@ func main() {
}
func run(cmd *cobra.Command, args []string) error {
if clearHistory {
historyHandler := history.New()
err := historyHandler.Delete()
if err != nil {
return err
}
fmt.Println("History successfully cleared.")
}
if len(args) == 0 {
return errors.New("you must specify your query")
if clearHistory {
return nil
} else {
return errors.New("you must specify your query")
}
}
secret := viper.GetString(secretEnv)
if secret == "" {
return errors.New("missing environment variable: " + secretEnv)
}
client := client.New(http.New().WithSecret(secret))
client := client.New(http.New().WithSecret(secret), history.New())
if queryMode {
result, err := client.Query(strings.Join(args, " "))
@@ -55,6 +71,5 @@ func run(cmd *cobra.Command, args []string) error {
return err
}
}
return nil
}

85
history/store.go Normal file
View File

@@ -0,0 +1,85 @@
package history
import (
"encoding/json"
"github.com/kardolus/chatgpt-cli/types"
"io/ioutil"
"os"
"path/filepath"
)
type Store interface {
Delete() error
Read() ([]types.Message, error)
Write([]types.Message) error
}
// Ensure RestCaller implements Caller interface
var _ Store = &FileIO{}
type FileIO struct {
}
func New() *FileIO {
return &FileIO{}
}
func (f *FileIO) Delete() error {
historyFilePath, err := getPath()
if err != nil {
return err
}
if _, err := os.Stat(historyFilePath); err == nil {
return os.Remove(historyFilePath)
}
return nil
}
func (f *FileIO) Read() ([]types.Message, error) {
historyFilePath, err := getPath()
if err != nil {
return nil, err
}
return parseFile(historyFilePath)
}
func (f *FileIO) Write(messages []types.Message) error {
historyFilePath, err := getPath()
if err != nil {
return err
}
data, err := json.Marshal(messages)
if err != nil {
return err
}
return ioutil.WriteFile(historyFilePath, data, 0644)
}
func getPath() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(homeDir, ".chatgpt-cli", "history"), nil
}
func parseFile(fileName string) ([]types.Message, error) {
var result []types.Message
buf, err := ioutil.ReadFile(fileName)
if err != nil {
return nil, err
}
if err := json.Unmarshal(buf, &result); err != nil {
return nil, err
}
return result, nil
}

View File

@@ -55,8 +55,7 @@ func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error)
if response.StatusCode >= 200 && response.StatusCode < 300 {
if stream {
ProcessResponse(response.Body, os.Stdout)
return nil, nil
return ProcessResponse(response.Body, os.Stdout), nil
} else {
result, err := ioutil.ReadAll(response.Body)
if err != nil {
@@ -69,7 +68,9 @@ func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error)
return nil, fmt.Errorf("http error: %d", response.StatusCode)
}
func ProcessResponse(r io.Reader, w io.Writer) {
func ProcessResponse(r io.Reader, w io.Writer) []byte {
var result []byte
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@@ -81,6 +82,7 @@ func ProcessResponse(r io.Reader, w io.Writer) {
if line == "[DONE]" {
_, _ = w.Write([]byte("\n"))
result = append(result, []byte("\n")...)
break
}
@@ -94,8 +96,10 @@ func ProcessResponse(r io.Reader, w io.Writer) {
for _, choice := range data.Choices {
if content, ok := choice.Delta["content"]; ok {
_, _ = w.Write([]byte(content))
result = append(result, []byte(content)...)
}
}
}
}
return result
}