Add ListModels

kardolus
2023-05-21 09:26:29 -04:00
parent 8e4d995f07
commit 1df84cbaab
10 changed files with 1851 additions and 74 deletions

View File

@@ -34,6 +34,7 @@ environment, demonstrating its practicality and effectiveness.
* **Custom context from local files**: Provide custom context through piping for the GPT model to reference during conversation.
* **Custom chat models**: Use a custom chat model by specifying the model name with the `-m` or `--model` flag.
* **Model listing**: Get a list of available models by using the `-l` or `--list-models` flag.
* **Viper integration**: Robust configuration management.
## Installation
@@ -130,6 +131,12 @@ Then, use the pipe feature to provide this context to ChatGPT:
cat context.txt | chatgpt "What kind of toy would Kya enjoy?"
```
6. To list all available models, use the `-l` or `--list-models` flag:
```shell
chatgpt --list-models
```
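Example output (illustrative; the exact list depends on the models available to your API key):
```shell
Available models:
* gpt-3.5-turbo (current)
- gpt-3.5-turbo-0301
```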
## Development
To start developing, set the `OPENAI_API_KEY` environment variable to

View File

@@ -33,6 +33,21 @@ func (m *MockCaller) EXPECT() *MockCallerMockRecorder {
return m.recorder
}
// Get mocks base method.
func (m *MockCaller) Get(arg0 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockCallerMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCaller)(nil).Get), arg0)
}
// Post mocks base method.
func (m *MockCaller) Post(arg0 string, arg1 []byte, arg2 bool) ([]byte, error) {
m.ctrl.T.Helper()

View File

@@ -14,12 +14,15 @@ import (
const (
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
GPTModel = "gpt-3.5-turbo"
ErrEmptyResponse = "empty response"
DefaultGPTModel = "gpt-3.5-turbo"
MaxTokenBufferPercentage = 20
MaxTokenSize = 4096
SystemRole = "system"
URL = "https://api.openai.com/v1/chat/completions"
CompletionURL = "https://api.openai.com/v1/chat/completions"
ModelURL = "https://api.openai.com/v1/models"
UserRole = "user"
gptPrefix = "gpt"
)
type Client struct {
@@ -35,7 +38,7 @@ func New(caller http.Caller, rw history.Store) *Client {
caller: caller,
readWriter: rw,
capacity: MaxTokenSize,
model: GPTModel,
model: DefaultGPTModel,
}
}
@@ -49,6 +52,51 @@ func (c *Client) WithModel(model string) *Client {
return c
}
// ListModels retrieves a list of all available models from the OpenAI API.
// The models are returned as a slice of strings, each entry representing a model ID.
// Only models whose ID starts with 'gpt' are included.
// The currently active model is marked with an asterisk (*) in the list.
// An error is returned if the HTTP call fails, the response cannot be
// decoded, or the API response is empty.
func (c *Client) ListModels() ([]string, error) {
var result []string
raw, err := c.caller.Get(ModelURL)
if err != nil {
return nil, err
}
var response types.ListModelsResponse
if err := c.processResponse(raw, &response); err != nil {
return nil, err
}
for _, model := range response.Data {
if strings.HasPrefix(model.Id, gptPrefix) {
if model.Id != c.model {
result = append(result, fmt.Sprintf("- %s", model.Id))
continue
}
result = append(result, fmt.Sprintf("* %s (current)", model.Id))
}
}
return result, nil
}
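A usage sketch, not part of this commit (the `http.New` and `history.New` constructors are assumptions for illustration; `client.New`, `WithModel`, and `ListModels` match the code above):

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/kardolus/chatgpt-cli/client"
	"github.com/kardolus/chatgpt-cli/history"
	"github.com/kardolus/chatgpt-cli/http"
)

func main() {
	// Hypothetical wiring: http.New and history.New are assumed constructors.
	caller := http.New().WithSecret(os.Getenv("OPENAI_API_KEY"))
	c := client.New(caller, history.New()).WithModel("gpt-4")

	// Each entry looks like "- gpt-3.5-turbo" or "* gpt-4 (current)".
	models, err := c.ListModels()
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range models {
		fmt.Println(m)
	}
}
```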
// ProvideContext adds custom context to the client's history by converting the
// provided string into a series of messages. This allows the ChatGPT API to have
// prior knowledge of the provided context when generating responses.
//
// The context string should contain the text you want to provide as context,
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
c.initHistory()
messages := createMessagesFromString(context)
c.History = append(c.History, messages...)
}
// Query sends a query to the API and returns the response as a string.
// It takes an input string as a parameter and returns a string containing
// the API response or an error if there's any issue during the process.
@@ -56,26 +104,21 @@ func (c *Client) WithModel(model string) *Client {
// call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) {
c.initHistory()
c.addQuery(input)
c.prepareQuery(input)
body, err := c.createBody(false)
if err != nil {
return "", err
}
raw, err := c.caller.Post(URL, body, false)
raw, err := c.caller.Post(CompletionURL, body, false)
if err != nil {
return "", err
}
if raw == nil {
return "", errors.New("empty response")
}
var response types.Response
if err := json.Unmarshal(raw, &response); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
var response types.CompletionsResponse
if err := c.processResponse(raw, &response); err != nil {
return "", err
}
if len(response.Choices) == 0 {
@@ -93,15 +136,14 @@ func (c *Client) Query(input string) (string, error) {
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
c.initHistory()
c.addQuery(input)
c.prepareQuery(input)
body, err := c.createBody(true)
if err != nil {
return err
}
result, err := c.caller.Post(URL, body, true)
result, err := c.caller.Post(CompletionURL, body, true)
if err != nil {
return err
}
@@ -111,21 +153,8 @@ func (c *Client) Stream(input string) error {
return nil
}
// ProvideContext adds custom context to the client's history by converting the
// provided string into a series of messages. This allows the ChatGPT API to have
// prior knowledge of the provided context when generating responses.
//
// The context string should contain the text you want to provide as context,
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
c.initHistory()
messages := createMessagesFromString(context)
c.History = append(c.History, messages...)
}
func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.Request{
body := types.CompletionsRequest{
Messages: c.History,
Model: c.model,
Stream: stream,
@@ -158,6 +187,23 @@ func (c *Client) addQuery(query string) {
c.truncateHistory()
}
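// prepareQuery initializes the history and appends the user input; shared by Query and Stream.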
func (c *Client) prepareQuery(input string) {
c.initHistory()
c.addQuery(input)
}
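// processResponse returns an error for an empty payload and otherwise decodes the raw JSON response into v.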
func (c *Client) processResponse(raw []byte, v interface{}) error {
if raw == nil {
return errors.New(ErrEmptyResponse)
}
if err := json.Unmarshal(raw, v); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
return nil
}
func (c *Client) truncateHistory() {
tokens, rolling := countTokens(c.History)
effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/golang/mock/mockgen/model"
"github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/types"
"github.com/kardolus/chatgpt-cli/utils"
"testing"
. "github.com/onsi/gomega"
@@ -60,7 +61,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, errors.New(errorMsg))
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -68,7 +69,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
it("throws an error when the response is empty", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
mockCaller.EXPECT().Post(client.URL, body, false).Return(nil, nil)
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(nil, nil)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -78,7 +79,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Post(client.URL, body, false).Return([]byte(malformed), nil)
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return([]byte(malformed), nil)
_, err := subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -87,7 +88,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
it("throws an error when the response is missing Choices", func() {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
response := &types.Response{
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
@@ -97,7 +98,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, body, false).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, nil)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -117,19 +118,19 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
FinishReason: "",
Index: 0,
}
response := &types.Response{
response := &types.CompletionsResponse{
ID: "id",
Object: "object",
Created: 0,
Model: client.GPTModel,
Model: client.DefaultGPTModel,
Choices: []types.Choice{choice},
}
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.URL, expectedBody, false).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, false).Return(respBytes, nil)
var request types.Request
var request types.CompletionsRequest
err = json.Unmarshal(expectedBody, &request)
Expect(err).NotTo(HaveOccurred())
@@ -228,7 +229,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
errorMsg := "error message"
mockCaller.EXPECT().Post(client.URL, body, true).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
err := subject.Stream(query)
Expect(err).To(HaveOccurred())
@@ -239,7 +240,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
testValidHTTPResponse := func(history []types.Message, expectedBody []byte) {
mockStore.EXPECT().Read().Return(history, nil).Times(1)
mockCaller.EXPECT().Post(client.URL, expectedBody, true).Return([]byte(answer), nil)
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
@@ -278,6 +279,44 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
})
})
})
when("ListModels()", func() {
it("throws an error when the http callout fails", func() {
errorMsg := "error message"
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
it("throws an error when the response is empty", func() {
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("empty response"))
})
it("throws an error when the response is a malformed json", func() {
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(HavePrefix("failed to decode response:"))
})
it("filters gpt models as expected", func() {
response, err := utils.FileToBytes("models.json")
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Get(client.ModelURL).Return(response, nil)
result, err := subject.ListModels()
Expect(err).NotTo(HaveOccurred())
Expect(result).NotTo(BeEmpty())
Expect(result).To(HaveLen(2))
Expect(result[0]).To(Equal("* gpt-3.5-turbo (current)"))
Expect(result[1]).To(Equal("- gpt-3.5-turbo-0301"))
})
})
when("ProvideContext()", func() {
it("updates the history with the provided context", func() {
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
@@ -298,8 +337,8 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}
func createBody(messages []types.Message, stream bool) ([]byte, error) {
req := types.Request{
Model: client.GPTModel,
req := types.CompletionsRequest{
Model: client.DefaultGPTModel,
Messages: messages,
Stream: stream,
}

View File

@@ -21,6 +21,7 @@ var (
clearHistory bool
showVersion bool
interactiveMode bool
listModels bool
modelName string
GitCommit string
GitVersion string
@@ -40,6 +41,7 @@ func main() {
rootCmd.PersistentFlags().BoolVarP(&queryMode, "query", "q", false, "Use query mode instead of stream mode")
rootCmd.PersistentFlags().BoolVarP(&clearHistory, "clear-history", "c", false, "Clear the history of ChatGPT CLI")
rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Display the version information")
rootCmd.PersistentFlags().BoolVarP(&listModels, "list-models", "l", false, "List available models")
rootCmd.PersistentFlags().StringVarP(&modelName, "model", "m", "", "Use a custom GPT model by specifying the model name")
viper.AutomaticEnv()
@@ -86,6 +88,18 @@ func run(cmd *cobra.Command, args []string) error {
return nil
}
if listModels {
models, err := client.ListModels()
if err != nil {
return err
}
fmt.Println("Available models:")
for _, model := range models {
fmt.Println(model)
}
return nil
}
if interactiveMode {
scanner := bufio.NewScanner(os.Stdin)
qNum := 1

View File

@@ -13,8 +13,20 @@ import (
"strings"
)
const (
bearer = "Bearer %s"
contentType = "application/json"
errFailedToRead = "failed to read response: %w"
errFailedToCreateRequest = "failed to create request: %w"
errFailedToMakeRequest = "failed to make request: %w"
errHTTP = "http error: %d"
headerAuthorization = "Authorization"
headerContentType = "Content-Type"
)
type Caller interface {
Post(url string, body []byte, stream bool) ([]byte, error)
Get(url string) ([]byte, error)
}
type RestCaller struct {
@@ -36,36 +48,12 @@ func (r *RestCaller) WithSecret(secret string) *RestCaller {
return r
}
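// Get issues a GET request to the given URL and returns the raw response body.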
func (r *RestCaller) Get(url string) ([]byte, error) {
return r.doRequest(http.MethodGet, url, nil, false)
}
func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error) {
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
if r.secret != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", r.secret))
}
req.Header.Set("Content-Type", "application/json")
response, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer response.Body.Close()
if response.StatusCode >= 200 && response.StatusCode < 300 {
if stream {
return ProcessResponse(response.Body, os.Stdout), nil
} else {
result, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return result, nil
}
}
return nil, fmt.Errorf("http error: %d", response.StatusCode)
return r.doRequest(http.MethodPost, url, body, stream)
}
func ProcessResponse(r io.Reader, w io.Writer) []byte {
@@ -103,3 +91,45 @@ func ProcessResponse(r io.Reader, w io.Writer) []byte {
}
return result
}
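// doRequest builds and executes the request; it streams the body to stdout when stream is true, returns the body otherwise, and maps non-2xx statuses to an error.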
func (r *RestCaller) doRequest(method, url string, body []byte, stream bool) ([]byte, error) {
req, err := r.newRequest(method, url, body)
if err != nil {
return nil, fmt.Errorf(errFailedToCreateRequest, err)
}
response, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf(errFailedToMakeRequest, err)
}
defer response.Body.Close()
if response.StatusCode < 200 || response.StatusCode >= 300 {
return nil, fmt.Errorf(errHTTP, response.StatusCode)
}
if stream {
return ProcessResponse(response.Body, os.Stdout), nil
}
result, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf(errFailedToRead, err)
}
return result, nil
}
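// newRequest constructs the request, setting the Authorization header when a secret is configured and the Content-Type header unconditionally.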
func (r *RestCaller) newRequest(method, url string, body []byte) (*http.Request, error) {
req, err := http.NewRequest(method, url, bytes.NewBuffer(body))
if err != nil {
return nil, err
}
if r.secret != "" {
req.Header.Set(headerAuthorization, fmt.Sprintf(bearer, r.secret))
}
req.Header.Set(headerContentType, contentType)
return req, nil
}
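As a sanity check on the refactor, a sketch of an in-package test (it assumes `RestCaller`'s unexported fields are `client *http.Client` and `secret string`, consistent with their use above):

```go
package http

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// doRequest should map any non-2xx status to an error rather than
// returning the body.
func TestDoRequestMapsHTTPErrors(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	}))
	defer srv.Close()

	r := &RestCaller{client: srv.Client()} // field name assumed from usage above
	if _, err := r.doRequest(http.MethodGet, srv.URL, nil, false); err == nil {
		t.Fatal("expected an http error for status 418, got nil")
	}
}
```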

resources/testdata/models.json (vendored, new file, +1,565 lines)

File diff suppressed because it is too large.

View File

@@ -1,6 +1,6 @@
package types
type Request struct {
type CompletionsRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
@@ -11,7 +11,7 @@ type Message struct {
Content string `json:"content"`
}
type Response struct {
type CompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`

types/models.go (new file, +29 lines)
View File

@@ -0,0 +1,29 @@
package types
type ListModelsResponse struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group interface{} `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
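For reference, an abridged example of the `/v1/models` payload this struct decodes (field values are illustrative):

```json
{
  "object": "list",
  "data": [
    {
      "id": "gpt-3.5-turbo",
      "object": "model",
      "created": 1677610602,
      "owned_by": "openai",
      "permission": [
        {
          "id": "modelperm-abc123",
          "object": "model_permission",
          "created": 1677610602,
          "allow_create_engine": false,
          "allow_sampling": true,
          "allow_logprobs": true,
          "allow_search_indices": false,
          "allow_view": true,
          "allow_fine_tuning": false,
          "organization": "*",
          "group": null,
          "is_blocking": false
        }
      ],
      "root": "gpt-3.5-turbo",
      "parent": null
    }
  ]
}
```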

utils/testutils.go (new file, +32 lines)
View File

@@ -0,0 +1,32 @@
package utils
import (
. "github.com/onsi/gomega"
"io/ioutil"
"path"
"path/filepath"
"runtime"
"strings"
)
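// FileToBytes reads a fixture from resources/testdata, resolving the path relative to this source file (with an adjusted path when running from a vendor directory).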
func FileToBytes(fileName string) ([]byte, error) {
_, thisFile, _, _ := runtime.Caller(0)
var (
urlPath string
err error
)
if strings.Contains(thisFile, "vendor") {
urlPath, err = filepath.Abs(path.Join(thisFile, "../../../../../..", "resources", "testdata", fileName))
} else {
urlPath, err = filepath.Abs(path.Join(thisFile, "../..", "resources", "testdata", fileName))
}
if err != nil {
return nil, err
}
Expect(urlPath).To(BeAnExistingFile())
return ioutil.ReadFile(urlPath)
}