Add ListModels
@@ -33,6 +33,21 @@ func (m *MockCaller) EXPECT() *MockCallerMockRecorder {
 	return m.recorder
 }
 
+// Get mocks base method.
+func (m *MockCaller) Get(arg0 string) ([]byte, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Get", arg0)
+	ret0, _ := ret[0].([]byte)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// Get indicates an expected call of Get.
+func (mr *MockCallerMockRecorder) Get(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCaller)(nil).Get), arg0)
+}
+
 // Post mocks base method.
 func (m *MockCaller) Post(arg0 string, arg1 []byte, arg2 bool) ([]byte, error) {
 	m.ctrl.T.Helper()
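These generated methods follow the standard gomock pattern: the mock method records the call, and the recorder method registers expectations. A minimal sketch of stubbing Get in a test; the controller wiring and payload are illustrative rather than part of this diff, and NewMockCaller is the usual mockgen-generated constructor:

	// Illustrative gomock setup for the generated MockCaller.
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockCaller := NewMockCaller(ctrl)
	mockCaller.EXPECT().
		Get("https://api.openai.com/v1/models").
		Return([]byte(`{"data":[]}`), nil)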
client/client.go
@@ -14,12 +14,15 @@ import (
 const (
 	AssistantContent         = "You are a helpful assistant."
 	AssistantRole            = "assistant"
-	GPTModel                 = "gpt-3.5-turbo"
+	ErrEmptyResponse         = "empty response"
+	DefaultGPTModel          = "gpt-3.5-turbo"
 	MaxTokenBufferPercentage = 20
 	MaxTokenSize             = 4096
 	SystemRole               = "system"
-	URL                      = "https://api.openai.com/v1/chat/completions"
+	CompletionURL            = "https://api.openai.com/v1/chat/completions"
+	ModelURL                 = "https://api.openai.com/v1/models"
 	UserRole                 = "user"
+	gptPrefix                = "gpt"
 )
 
 type Client struct {
@@ -35,7 +38,7 @@ func New(caller http.Caller, rw history.Store) *Client {
 		caller:     caller,
 		readWriter: rw,
 		capacity:   MaxTokenSize,
-		model:      GPTModel,
+		model:      DefaultGPTModel,
 	}
 }
 
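The constructor keeps gpt-3.5-turbo as the default, now under the clearer DefaultGPTModel name; callers can still override it through WithModel, which appears as context in the next hunk. A hypothetical override (the model name is illustrative):

	// Hypothetical override of the default model; "gpt-4" is illustrative.
	c := client.New(caller, rw).WithModel("gpt-4")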
@@ -49,6 +52,51 @@ func (c *Client) WithModel(model string) *Client {
 	return c
 }
 
+// ListModels retrieves a list of all available models from the OpenAI API.
+// The models are returned as a slice of strings, each entry representing a model ID.
+// Models that have an ID starting with 'gpt' are included.
+// The currently active model is marked with an asterisk (*) in the list.
+// In case of an error during the retrieval or processing of the models,
+// the method returns an error. If the API response is empty, an error is returned as well.
+func (c *Client) ListModels() ([]string, error) {
+	var result []string
+
+	raw, err := c.caller.Get(ModelURL)
+	if err != nil {
+		return nil, err
+	}
+
+	var response types.ListModelsResponse
+	if err := c.processResponse(raw, &response); err != nil {
+		return nil, err
+	}
+
+	for _, model := range response.Data {
+		if strings.HasPrefix(model.Id, gptPrefix) {
+			if model.Id != DefaultGPTModel {
+				result = append(result, fmt.Sprintf("- %s", model.Id))
+				continue
+			}
+			result = append(result, fmt.Sprintf("* %s (current)", model.Id))
+		}
+	}
+
+	return result, nil
+}
+
+// ProvideContext adds custom context to the client's history by converting the
+// provided string into a series of messages. This allows the ChatGPT API to have
+// prior knowledge of the provided context when generating responses.
+//
+// The context string should contain the text you want to provide as context,
+// and the method will split it into messages, preserving punctuation and special
+// characters.
+func (c *Client) ProvideContext(context string) {
+	c.initHistory()
+	messages := createMessagesFromString(context)
+	c.History = append(c.History, messages...)
+}
+
 // Query sends a query to the API and returns the response as a string.
 // It takes an input string as a parameter and returns a string containing
 // the API response or an error if there's any issue during the process.
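A sketch of how a caller might consume the new method; httpCaller and store are hypothetical stand-ins for the http.Caller and history.Store arguments New expects, and the printing is illustrative:

	// Hypothetical consumer of ListModels; wiring names are illustrative.
	models, err := client.New(httpCaller, store).ListModels()
	if err != nil {
		log.Fatal(err)
	}
	for _, model := range models {
		fmt.Println(model) // e.g. "* gpt-3.5-turbo (current)" or "- gpt-3.5-turbo-0301"
	}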
@@ -56,26 +104,21 @@ func (c *Client) WithModel(model string) *Client {
 // call using the Post method. If the response is not empty, it decodes the
 // response JSON and returns the content of the first choice.
 func (c *Client) Query(input string) (string, error) {
-	c.initHistory()
-	c.addQuery(input)
+	c.prepareQuery(input)
 
 	body, err := c.createBody(false)
 	if err != nil {
 		return "", err
 	}
 
-	raw, err := c.caller.Post(URL, body, false)
+	raw, err := c.caller.Post(CompletionURL, body, false)
 	if err != nil {
 		return "", err
 	}
 
-	if raw == nil {
-		return "", errors.New("empty response")
-	}
-
-	var response types.Response
-	if err := json.Unmarshal(raw, &response); err != nil {
-		return "", fmt.Errorf("failed to decode response: %w", err)
+	var response types.CompletionsResponse
+	if err := c.processResponse(raw, &response); err != nil {
+		return "", err
 	}
 
 	if len(response.Choices) == 0 {
@@ -93,15 +136,14 @@ func (c *Client) Query(input string) (string, error) {
 // input and then makes an API call using the Post method. The actual
 // processing of the streamed response is done in the Post method.
 func (c *Client) Stream(input string) error {
-	c.initHistory()
-	c.addQuery(input)
+	c.prepareQuery(input)
 
 	body, err := c.createBody(true)
 	if err != nil {
 		return err
 	}
 
-	result, err := c.caller.Post(URL, body, true)
+	result, err := c.caller.Post(CompletionURL, body, true)
 	if err != nil {
 		return err
 	}
@@ -111,21 +153,8 @@ func (c *Client) Stream(input string) error {
 	return nil
 }
 
-// ProvideContext adds custom context to the client's history by converting the
-// provided string into a series of messages. This allows the ChatGPT API to have
-// prior knowledge of the provided context when generating responses.
-//
-// The context string should contain the text you want to provide as context,
-// and the method will split it into messages, preserving punctuation and special
-// characters.
-func (c *Client) ProvideContext(context string) {
-	c.initHistory()
-	messages := createMessagesFromString(context)
-	c.History = append(c.History, messages...)
-}
-
 func (c *Client) createBody(stream bool) ([]byte, error) {
-	body := types.Request{
+	body := types.CompletionsRequest{
 		Messages: c.History,
 		Model:    c.model,
 		Stream:   stream,
@@ -158,6 +187,23 @@ func (c *Client) addQuery(query string) {
 	c.truncateHistory()
 }
 
+func (c *Client) prepareQuery(input string) {
+	c.initHistory()
+	c.addQuery(input)
+}
+
+func (c *Client) processResponse(raw []byte, v interface{}) error {
+	if raw == nil {
+		return errors.New(ErrEmptyResponse)
+	}
+
+	if err := json.Unmarshal(raw, v); err != nil {
+		return fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return nil
+}
+
 func (c *Client) truncateHistory() {
 	tokens, rolling := countTokens(c.History)
 	effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
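The helper centralizes the two failure modes previously inlined in Query: a nil body yields ErrEmptyResponse, and invalid JSON yields a wrapped decode error. Its empty-interface target is what lets it decode both CompletionsResponse and ListModelsResponse. A hypothetical call illustrating the nil-body path:

	// Hypothetical illustration: a nil body surfaces ErrEmptyResponse.
	var response types.ListModelsResponse
	if err := c.processResponse(nil, &response); err != nil {
		fmt.Println(err) // prints "empty response"
	}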
client/client_test.go
@@ -7,6 +7,7 @@ import (
 	_ "github.com/golang/mock/mockgen/model"
 	"github.com/kardolus/chatgpt-cli/client"
 	"github.com/kardolus/chatgpt-cli/types"
+	"github.com/kardolus/chatgpt-cli/utils"
 	"testing"
 
 	. "github.com/onsi/gomega"
@@ -60,7 +61,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 			mockStore.EXPECT().Read().Return(nil, nil).Times(1)
 
 			errorMsg := "error message"
-			mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
+			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg))
 
 			_, err := subject.Query(query)
 			Expect(err).To(HaveOccurred())
@@ -68,7 +69,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 		})
 		it("throws an error when the response is empty", func() {
 			mockStore.EXPECT().Read().Return(nil, nil).Times(1)
-			mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
+			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil)
 
 			_, err := subject.Query(query)
 			Expect(err).To(HaveOccurred())
@@ -78,7 +79,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 			mockStore.EXPECT().Read().Return(nil, nil).Times(1)
 
 			malformed := `{"invalid":"json"` // missing closing brace
-			mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
+			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil)
 
 			_, err := subject.Query(query)
 			Expect(err).To(HaveOccurred())
@@ -87,7 +88,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 		it("throws an error when the response is missing Choices", func() {
 			mockStore.EXPECT().Read().Return(nil, nil).Times(1)
 
-			response := &types.Response{
+			response := &types.CompletionsResponse{
 				ID:      "id",
 				Object:  "object",
 				Created: 0,
@@ -97,7 +98,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 
 			respBytes, err := json.Marshal(response)
 			Expect(err).NotTo(HaveOccurred())
-			mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
+			mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil)
 
 			_, err = subject.Query(query)
 			Expect(err).To(HaveOccurred())
@@ -117,19 +118,19 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 				FinishReason: "",
 				Index:        0,
 			}
-			response := &types.Response{
+			response := &types.CompletionsResponse{
 				ID:      "id",
 				Object:  "object",
 				Created: 0,
-				Model:   client.GPTModel,
+				Model:   client.DefaultGPTModel,
 				Choices: []types.Choice{choice},
 			}
 
 			respBytes, err := json.Marshal(response)
 			Expect(err).NotTo(HaveOccurred())
-			mockCaller.EXPECT().Post(client.URL, expectedBody, false).Return(respBytes, nil)
+			mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, false).Return(respBytes, nil)
 
-			var request types.Request
+			var request types.CompletionsRequest
 			err = json.Unmarshal(expectedBody, &request)
 			Expect(err).NotTo(HaveOccurred())
@@ -228,7 +229,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 			mockStore.EXPECT().Read().Return(nil, nil).Times(1)
 
 			errorMsg := "error message"
-			mockCaller.EXPECT().Post(client.URL, body, true).Return(nil, errors.New(errorMsg))
+			mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
 
 			err := subject.Stream(query)
 			Expect(err).To(HaveOccurred())
@@ -239,7 +240,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 
 		testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
 			mockStore.EXPECT().Read().Return(history, nil).Times(1)
-			mockCaller.EXPECT().Post(client.URL, expectedBody, true).Return([]byte(answer), nil)
+			mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
 
 			messages = createMessages(history, query)
 
@@ -278,6 +279,44 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 			})
 		})
 	})
+	when("ListModels()", func() {
+		it("throws an error when the http callout fails", func() {
+			errorMsg := "error message"
+			mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
+
+			_, err := subject.ListModels()
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(Equal(errorMsg))
+		})
+		it("throws an error when the response is empty", func() {
+			mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
+
+			_, err := subject.ListModels()
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(Equal("empty response"))
+		})
+		it("throws an error when the response is a malformed json", func() {
+			malformed := `{"invalid":"json"` // missing closing brace
+			mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
+
+			_, err := subject.ListModels()
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
+		})
+		it("filters gpt models as expected", func() {
+			response, err := utils.FileToBytes("models.json")
+			Expect(err).NotTo(HaveOccurred())
+
+			mockCaller.EXPECT().Get(client.ModelURL).Return(response, nil)
+
+			result, err := subject.ListModels()
+			Expect(err).NotTo(HaveOccurred())
+			Expect(result).NotTo(BeEmpty())
+			Expect(result).To(HaveLen(2))
+			Expect(result[0]).To(Equal("* gpt-3.5-turbo (current)"))
+			Expect(result[1]).To(Equal("- gpt-3.5-turbo-0301"))
+		})
+	})
 	when("ProvideContext()", func() {
 		it("updates the history with the provided context", func() {
 			context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
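The models.json fixture itself is not shown in this diff. A minimal body that would satisfy these assertions, assuming types.ListModelsResponse unmarshals a data array of objects carrying an id field, might look like the sketch below; the non-gpt entry is there only to exercise the gptPrefix filter that drops it.

	// Hypothetical fixture content; only "data" and "id" are implied by
	// ListModels above, any other fields are assumptions.
	var modelsJSON = []byte(`{
		"data": [
			{"id": "gpt-3.5-turbo"},
			{"id": "gpt-3.5-turbo-0301"},
			{"id": "text-davinci-003"}
		]
	}`)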
@@ -298,8 +337,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
 }
 
 func createBody(messages []types.Message, stream bool) ([]byte, error) {
-	req := types.Request{
-		Model:    client.GPTModel,
+	req := types.CompletionsRequest{
+		Model:    client.DefaultGPTModel,
 		Messages: messages,
 		Stream:   stream,
 	}