Add streaming

This commit is contained in:
kardolus
2023-05-02 10:14:38 -04:00
parent 0376a9c736
commit 7704fbc935
8 changed files with 160 additions and 26 deletions

View File

@@ -22,13 +22,19 @@ func New(caller http.Caller) *Client {
return &Client{caller: caller}
}
// Query sends a query to the API and returns the response as a string.
// It takes an input string as a parameter and returns a string containing
// the API response or an error if there's any issue during the process.
// The method creates a request body with the input and then makes an API
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
body, err := CreateBody(input)
body, err := CreateBody(input, false)
if err != nil {
return "", err
}
raw, err := c.caller.Post(URL, body)
raw, err := c.caller.Post(URL, body, false)
if err != nil {
return "", err
}
@@ -49,7 +55,26 @@ func (c *Client) Query(input string) (string, error) {
return response.Choices[0].Message.Content, nil
}
func CreateBody(query string) ([]byte, error) {
// Stream sends a query to the API and processes the response as a stream.
// It takes an input string as a parameter and returns an error if there's
// any issue during the process. The method creates a request body with the
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
body, err := CreateBody(input, true)
if err != nil {
return err
}
_, err = c.caller.Post(URL, body, true)
if err != nil {
return err
}
return nil
}
func CreateBody(query string, stream bool) ([]byte, error) {
message := types.Message{
Role: role,
Content: query,
@@ -58,6 +83,7 @@ func CreateBody(query string) ([]byte, error) {
body := types.Request{
Model: model,
Messages: []types.Message{message},
Stream: stream,
}
result, err := json.Marshal(body)

View File

@@ -48,28 +48,28 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
)
it.Before(func() {
body, err = client.CreateBody(query)
body, err = client.CreateBody(query, false)
Expect(err).NotTo(HaveOccurred())
})
it("throws an error when the http callout fails", func() {
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockCaller.EXPECT().Post(client.URL, body).Return(nil, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
malformed := "{no"
mockCaller.EXPECT().Post(client.URL, body).Return([]byte(malformed), nil)
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -86,7 +86,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -111,7 +111,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
result, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred())

View File

@@ -34,16 +34,16 @@ func (m *MockCaller) EXPECT() *MockCallerMockRecorder {
}
// Post mocks base method.
func (m *MockCaller) Post(arg0 string, arg1 []byte) ([]byte, error) {
func (m *MockCaller) Post(arg0 string, arg1 []byte, arg2 bool) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Post", arg0, arg1)
ret := m.ctrl.Call(m, "Post", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Post indicates an expected call of Post.
func (mr *MockCallerMockRecorder) Post(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockCallerMockRecorder) Post(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2)
}