Add tests for Stream()

This commit is contained in:
kardolus
2023-05-03 10:49:12 -04:00
parent eae488c3b1
commit 0d50e666a5
2 changed files with 73 additions and 8 deletions

View File

@@ -115,8 +115,6 @@ func (c *Client) initHistory(query string) {
}}
}
// TODO Test the string returned from Stream
// TODO Write delete-specific tests (on store)
// TODO implement sliding window
c.history = append(c.history, message)

View File

@@ -29,6 +29,8 @@ func TestUnitClient(t *testing.T) {
}
func testClient(t *testing.T, when spec.G, it spec.S) {
const query = "test query"
it.Before(func() {
RegisterTestingT(t)
mockCtrl = gomock.NewController(t)
@@ -42,8 +44,6 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
when("Query()", func() {
const query = "test query"
var (
body []byte
messages []types.Message
@@ -52,7 +52,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
})
@@ -159,7 +159,74 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
},
}
messages = createMessages(history, query)
body, err = createBody(messages)
body, err = createBody(messages, false)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
})
})
})
when("Stream()", func() {
var (
body []byte
messages []types.Message
err error
)
it.Before(func() {
messages = createMessages(nil, query)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
})
it("throws an error when the http callout fails", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body, true).Return(nil, errors.New(errorMsg))
err := subject.Stream(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
when("a valid http response is received", func() {
const answer = "answer"
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
mockCaller.EXPECT().Post(client.URL, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
mockStore.EXPECT().Write(append(messages, types.Message{
Role: client.AssistantRole,
Content: answer,
}))
err := subject.Stream(query)
Expect(err).NotTo(HaveOccurred())
}
it("returns the expected result for an empty history", func() {
testValidHTTPResponse(nil, body)
})
it("returns the expected result for a non-empty history", func() {
history := []types.Message{
{
Role: client.SystemRole,
Content: client.AssistantContent,
},
{
Role: client.UserRole,
Content: "question 1",
},
{
Role: client.AssistantRole,
Content: "answer 1",
},
}
messages = createMessages(history, query)
body, err = createBody(messages, true)
Expect(err).NotTo(HaveOccurred())
testValidHTTPResponse(history, body)
@@ -168,11 +235,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
}
func createBody(messages []types.Message) ([]byte, error) {
func createBody(messages []types.Message, stream bool) ([]byte, error) {
req := types.Request{
Model: client.GPTModel,
Messages: messages,
Stream: false,
Stream: stream,
}
return json.Marshal(req)