Refactor Query function to include token usage in response

This commit is contained in:
kardolus
2024-04-20 12:08:33 -04:00
parent 154f0331a4
commit ca2e544603
5 changed files with 64 additions and 40 deletions

View File

@@ -104,37 +104,35 @@ func (c *Client) ProvideContext(context string) {
c.History = append(c.History, messages...)
}
// Query sends a query to the API and returns the response as a string.
// It takes an input string as a parameter and returns a string containing
// the API response or an error if there's any issue during the process.
// The method creates a request body with the input and then makes an API
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
// Query sends a query to the API, returning the response as a string along with the token usage.
// It takes an input string, constructs a request body, and makes a POST API call.
// Returns the API response string, the number of tokens used, and an error if any issues occur.
// If the response contains choices, it decodes the JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, int, error) {
c.prepareQuery(input)
body, err := c.createBody(false)
if err != nil {
return "", err
return "", 0, err
}
raw, err := c.caller.Post(c.getEndpoint(c.Config.CompletionsPath), body, false)
if err != nil {
return "", err
return "", 0, err
}
var response types.CompletionsResponse
if err := c.processResponse(raw, &response); err != nil {
return "", err
return "", 0, err
}
if len(response.Choices) == 0 {
return "", errors.New("no responses returned")
return "", response.Usage.TotalTokens, errors.New("no responses returned")
}
c.updateHistory(response.Choices[0].Message.Content)
return response.Choices[0].Message.Content, nil
return response.Choices[0].Message.Content, response.Usage.TotalTokens, nil
}
// Stream sends a query to the API and processes the response as a stream.

View File

@@ -153,7 +153,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, false).Return(respBytes, tt.postError)
_, err = subject.Query(query)
_, _, err = subject.Query(query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
})
@@ -161,7 +161,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
when("a valid http response is received", func() {
testValidHTTPResponse := func(subject *client.Client, history []types.Message, expectedBody []byte, omitHistory bool) {
const answer = "content"
const (
answer = "content"
tokens = 789
)
choice := types.Choice{
Message: types.Message{
@@ -177,6 +180,11 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Created: 0,
Model: subject.Config.Model,
Choices: []types.Choice{choice},
Usage: types.Usage{
PromptTokens: 123,
CompletionTokens: 456,
TotalTokens: tokens,
},
}
respBytes, err := json.Marshal(response)
@@ -194,9 +202,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}))
}
result, err := subject.Query(query)
result, usage, err := subject.Query(query)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(answer))
Expect(usage).To(Equal(tokens))
}
it("uses the values specified by the configuration instead of the default values", func() {
const (