mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add streaming
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# ChatGPT CLI
|
||||
|
||||
A Proof of Concept (POC) for building ChatGPT clients.
|
||||
This project is a Proof of Concept (POC) that demonstrates how to build ChatGPT clients with streaming support in a
|
||||
Command-Line Interface (CLI) environment.
|
||||
|
||||

|
||||
|
||||
@@ -30,7 +31,7 @@ building the application:
|
||||
|
||||
## Up Next
|
||||
|
||||
* Stream output rather than having text blobs
|
||||
* Add a command line flag for non-streaming
|
||||
* Use Viper for command line parsing
|
||||
|
||||
## Links
|
||||
|
||||
@@ -22,13 +22,19 @@ func New(caller http.Caller) *Client {
|
||||
return &Client{caller: caller}
|
||||
}
|
||||
|
||||
// Query sends a query to the API and returns the response as a string.
|
||||
// It takes an input string as a parameter and returns a string containing
|
||||
// the API response or an error if there's any issue during the process.
|
||||
// The method creates a request body with the input and then makes an API
|
||||
// call using the Post method. If the response is not empty, it decodes the
|
||||
// response JSON and returns the content of the first choice.
|
||||
func (c *Client) Query(input string) (string, error) {
|
||||
body, err := CreateBody(input)
|
||||
body, err := CreateBody(input, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
raw, err := c.caller.Post(URL, body)
|
||||
raw, err := c.caller.Post(URL, body, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -49,7 +55,26 @@ func (c *Client) Query(input string) (string, error) {
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func CreateBody(query string) ([]byte, error) {
|
||||
// Stream sends a query to the API and processes the response as a stream.
|
||||
// It takes an input string as a parameter and returns an error if there's
|
||||
// any issue during the process. The method creates a request body with the
|
||||
// input and then makes an API call using the Post method. The actual
|
||||
// processing of the streamed response is done in the Post method.
|
||||
func (c *Client) Stream(input string) error {
|
||||
body, err := CreateBody(input, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.caller.Post(URL, body, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateBody(query string, stream bool) ([]byte, error) {
|
||||
message := types.Message{
|
||||
Role: role,
|
||||
Content: query,
|
||||
@@ -58,6 +83,7 @@ func CreateBody(query string) ([]byte, error) {
|
||||
body := types.Request{
|
||||
Model: model,
|
||||
Messages: []types.Message{message},
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(body)
|
||||
|
||||
@@ -48,28 +48,28 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
)
|
||||
|
||||
it.Before(func() {
|
||||
body, err = client.CreateBody(query)
|
||||
body, err = client.CreateBody(query, false)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
it("throws an error when the http callout fails", func() {
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.URL, body).Return(nil, errors.New(errorMsg))
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal(errorMsg))
|
||||
})
|
||||
it("throws an error when the response is empty", func() {
|
||||
mockCaller.EXPECT().Post(client.URL, body).Return(nil, nil)
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("empty response"))
|
||||
})
|
||||
it("throws an error when the response is a malformed json", func() {
|
||||
malformed := "{no"
|
||||
mockCaller.EXPECT().Post(client.URL, body).Return([]byte(malformed), nil)
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -86,7 +86,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -111,7 +111,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
|
||||
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
|
||||
|
||||
result, err := subject.Query(query)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
@@ -34,16 +34,16 @@ func (m *MockCaller) EXPECT() *MockCallerMockRecorder {
|
||||
}
|
||||
|
||||
// Post mocks base method.
|
||||
func (m *MockCaller) Post(arg0 string, arg1 []byte) ([]byte, error) {
|
||||
func (m *MockCaller) Post(arg0 string, arg1 []byte, arg2 bool) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Post", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "Post", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Post indicates an expected call of Post.
|
||||
func (mr *MockCallerMockRecorder) Post(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockCallerMockRecorder) Post(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-poc/client"
|
||||
"github.com/kardolus/chatgpt-poc/http"
|
||||
"log"
|
||||
@@ -35,11 +34,9 @@ func run() error {
|
||||
}
|
||||
client := client.New(http.New().WithSecret(secret))
|
||||
|
||||
result, err := client.Query(strings.Join(os.Args[1:], " "))
|
||||
if err != nil {
|
||||
if err := client.Stream(strings.Join(os.Args[1:], " ")); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(result)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
54
http/http.go
54
http/http.go
@@ -1,14 +1,20 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/kardolus/chatgpt-poc/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Caller interface {
|
||||
Post(url string, body []byte) ([]byte, error)
|
||||
Post(url string, body []byte, stream bool) ([]byte, error)
|
||||
}
|
||||
|
||||
type RestCaller struct {
|
||||
@@ -30,7 +36,7 @@ func (r *RestCaller) WithSecret(secret string) *RestCaller {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RestCaller) Post(url string, body []byte) ([]byte, error) {
|
||||
func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
@@ -48,12 +54,48 @@ func (r *RestCaller) Post(url string, body []byte) ([]byte, error) {
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode >= 200 && response.StatusCode < 300 {
|
||||
result, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
if stream {
|
||||
ProcessResponse(response.Body, os.Stdout)
|
||||
return nil, nil
|
||||
} else {
|
||||
result, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("http error: %d", response.StatusCode)
|
||||
}
|
||||
|
||||
func ProcessResponse(r io.Reader, w io.Writer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
line = line[6:] // Skip the "data: " prefix
|
||||
if len(line) < 6 {
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "[DONE]" {
|
||||
_, _ = w.Write([]byte("\n"))
|
||||
break
|
||||
}
|
||||
|
||||
var data types.Data
|
||||
err := json.Unmarshal([]byte(line), &data)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(w, "Error: %s\n", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, choice := range data.Choices {
|
||||
if content, ok := choice.Delta["content"]; ok {
|
||||
_, _ = w.Write([]byte(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
55
http/http_test.go
Normal file
55
http/http_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/kardolus/chatgpt-poc/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sclevine/spec"
|
||||
"github.com/sclevine/spec/report"
|
||||
)
|
||||
|
||||
func TestUnitHTTP(t *testing.T) {
|
||||
spec.Run(t, "Testing the HTTP Client", testHTTP, spec.Report(report.Terminal{}))
|
||||
}
|
||||
|
||||
func testHTTP(t *testing.T, when spec.G, it spec.S) {
|
||||
it.Before(func() {
|
||||
RegisterTestingT(t)
|
||||
})
|
||||
|
||||
when("ProcessResponse()", func() {
|
||||
it("parses a stream as expected", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
http.ProcessResponse(strings.NewReader(stream), buf)
|
||||
output := buf.String()
|
||||
Expect(output).To(Equal("a b c\n"))
|
||||
})
|
||||
it("throws an error when the json is invalid", func() {
|
||||
input := `data: {"invalid":"json"` // missing closing brace
|
||||
expectedOutput := "Error: unexpected end of JSON input\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
http.ProcessResponse(strings.NewReader(input), &buf)
|
||||
output := buf.String()
|
||||
|
||||
Expect(output).To(Equal(expectedOutput))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const stream = `
|
||||
data: {"id":"id-1","object":"chat.completion.chunk","created":1,"model":"model-1","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"id-2","object":"chat.completion.chunk","created":2,"model":"model-1","choices":[{"delta":{"content":"a"},"index":0,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"id-3","object":"chat.completion.chunk","created":3,"model":"model-1","choices":[{"delta":{"content":" b"},"index":0,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"id-4","object":"chat.completion.chunk","created":4,"model":"model-1","choices":[{"delta":{"content":" c"},"index":0,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"id-5","object":"chat.completion.chunk","created":5,"model":"model-1","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
`
|
||||
@@ -3,6 +3,7 @@ package types
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -28,3 +29,15 @@ type Choice struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta map[string]string `json:"delta"`
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user