Add streaming

This commit is contained in:
kardolus
2023-05-02 10:14:38 -04:00
parent 0376a9c736
commit 7704fbc935
8 changed files with 160 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
# ChatGPT CLI
A Proof of Concept (POC) for building ChatGPT clients.
This project is a Proof of Concept (POC) that demonstrates how to build ChatGPT clients with streaming support in a
Command-Line Interface (CLI) environment.
![a screenshot](resources/screenshot.png)
@@ -30,7 +31,7 @@ building the application:
## Up Next
* Stream output rather than having text blobs
* Add a command line flag for non-streaming
* Use Viper for command line parsing
## Links

View File

@@ -22,13 +22,19 @@ func New(caller http.Caller) *Client {
return &Client{caller: caller}
}
// Query sends a query to the API and returns the response as a string.
// It takes an input string as a parameter and returns a string containing
// the API response or an error if there's any issue during the process.
// The method creates a request body with the input and then makes an API
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
body, err := CreateBody(input)
body, err := CreateBody(input, false)
if err != nil {
return "", err
}
raw, err := c.caller.Post(URL, body)
raw, err := c.caller.Post(URL, body, false)
if err != nil {
return "", err
}
@@ -49,7 +55,26 @@ func (c *Client) Query(input string) (string, error) {
return response.Choices[0].Message.Content, nil
}
func CreateBody(query string) ([]byte, error) {
// Stream sends a query to the API and processes the response as a stream.
// It takes an input string as a parameter and returns an error if there's
// any issue during the process. The method creates a request body with the
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
body, err := CreateBody(input, true)
if err != nil {
return err
}
_, err = c.caller.Post(URL, body, true)
if err != nil {
return err
}
return nil
}
func CreateBody(query string, stream bool) ([]byte, error) {
message := types.Message{
Role: role,
Content: query,
@@ -58,6 +83,7 @@ func CreateBody(query string) ([]byte, error) {
body := types.Request{
Model: model,
Messages: []types.Message{message},
Stream: stream,
}
result, err := json.Marshal(body)

View File

@@ -48,28 +48,28 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
)
it.Before(func() {
body, err = client.CreateBody(query)
body, err = client.CreateBody(query, false)
Expect(err).NotTo(HaveOccurred())
})
it("throws an error when the http callout fails", func() {
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockCaller.EXPECT().Post(client.URL, body).Return(nil, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
malformed := "{no"
mockCaller.EXPECT().Post(client.URL, body).Return([]byte(malformed), nil)
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -86,7 +86,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -111,7 +111,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
result, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred())

View File

@@ -34,16 +34,16 @@ func (m *MockCaller) EXPECT() *MockCallerMockRecorder {
}
// Post mocks base method.
func (m *MockCaller) Post(arg0 string, arg1 []byte) ([]byte, error) {
func (m *MockCaller) Post(arg0 string, arg1 []byte, arg2 bool) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Post", arg0, arg1)
ret := m.ctrl.Call(m, "Post", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Post indicates an expected call of Post.
func (mr *MockCallerMockRecorder) Post(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockCallerMockRecorder) Post(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2)
}

View File

@@ -2,7 +2,6 @@ package main
import (
"errors"
"fmt"
"github.com/kardolus/chatgpt-poc/client"
"github.com/kardolus/chatgpt-poc/http"
"log"
@@ -35,11 +34,9 @@ func run() error {
}
client := client.New(http.New().WithSecret(secret))
result, err := client.Query(strings.Join(os.Args[1:], " "))
if err != nil {
if err := client.Stream(strings.Join(os.Args[1:], " ")); err != nil {
return err
}
fmt.Println(result)
return nil
}

View File

@@ -1,14 +1,20 @@
package http
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/kardolus/chatgpt-poc/types"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
)
type Caller interface {
Post(url string, body []byte) ([]byte, error)
Post(url string, body []byte, stream bool) ([]byte, error)
}
type RestCaller struct {
@@ -30,7 +36,7 @@ func (r *RestCaller) WithSecret(secret string) *RestCaller {
return r
}
func (r *RestCaller) Post(url string, body []byte) ([]byte, error) {
func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error) {
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
@@ -48,12 +54,48 @@ func (r *RestCaller) Post(url string, body []byte) ([]byte, error) {
defer response.Body.Close()
if response.StatusCode >= 200 && response.StatusCode < 300 {
result, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
if stream {
ProcessResponse(response.Body, os.Stdout)
return nil, nil
} else {
result, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return result, nil
}
return result, nil
}
return nil, fmt.Errorf("http error: %d", response.StatusCode)
}
func ProcessResponse(r io.Reader, w io.Writer) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data:") {
line = line[6:] // Skip the "data: " prefix
if len(line) < 6 {
continue
}
if line == "[DONE]" {
_, _ = w.Write([]byte("\n"))
break
}
var data types.Data
err := json.Unmarshal([]byte(line), &data)
if err != nil {
_, _ = fmt.Fprintf(w, "Error: %s\n", err.Error())
continue
}
for _, choice := range data.Choices {
if content, ok := choice.Delta["content"]; ok {
_, _ = w.Write([]byte(content))
}
}
}
}
}

55
http/http_test.go Normal file
View File

@@ -0,0 +1,55 @@
package http_test
import (
"bytes"
"github.com/kardolus/chatgpt-poc/http"
"strings"
"testing"
. "github.com/onsi/gomega"
"github.com/sclevine/spec"
"github.com/sclevine/spec/report"
)
func TestUnitHTTP(t *testing.T) {
spec.Run(t, "Testing the HTTP Client", testHTTP, spec.Report(report.Terminal{}))
}
func testHTTP(t *testing.T, when spec.G, it spec.S) {
it.Before(func() {
RegisterTestingT(t)
})
when("ProcessResponse()", func() {
it("parses a stream as expected", func() {
buf := &bytes.Buffer{}
http.ProcessResponse(strings.NewReader(stream), buf)
output := buf.String()
Expect(output).To(Equal("a b c\n"))
})
it("throws an error when the json is invalid", func() {
input := `data: {"invalid":"json"` // missing closing brace
expectedOutput := "Error: unexpected end of JSON input\n"
var buf bytes.Buffer
http.ProcessResponse(strings.NewReader(input), &buf)
output := buf.String()
Expect(output).To(Equal(expectedOutput))
})
})
}
const stream = `
data: {"id":"id-1","object":"chat.completion.chunk","created":1,"model":"model-1","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
data: {"id":"id-2","object":"chat.completion.chunk","created":2,"model":"model-1","choices":[{"delta":{"content":"a"},"index":0,"finish_reason":null}]}
data: {"id":"id-3","object":"chat.completion.chunk","created":3,"model":"model-1","choices":[{"delta":{"content":" b"},"index":0,"finish_reason":null}]}
data: {"id":"id-4","object":"chat.completion.chunk","created":4,"model":"model-1","choices":[{"delta":{"content":" c"},"index":0,"finish_reason":null}]}
data: {"id":"id-5","object":"chat.completion.chunk","created":5,"model":"model-1","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
data: [DONE]
`

View File

@@ -3,6 +3,7 @@ package types
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
}
type Message struct {
@@ -28,3 +29,15 @@ type Choice struct {
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
}
type Data struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Delta map[string]string `json:"delta"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}