Add method to override the service URL

This commit is contained in:
kardolus
2023-06-14 10:54:58 -04:00
parent 30104b1a8e
commit e34a8759e0
5 changed files with 30 additions and 19 deletions

View File

@@ -18,11 +18,12 @@ const (
AssistantRole = "assistant"
ErrEmptyResponse = "empty response"
DefaultGPTModel = "gpt-3.5-turbo"
DefaultServiceURL = "https://api.openai.com"
CompletionPath = "/v1/chat/completions"
ModelPath = "/v1/models"
MaxTokenBufferPercentage = 20
MaxTokenSize = 4096
SystemRole = "system"
CompletionURL = "https://api.openai.com/v1/chat/completions"
ModelURL = "https://api.openai.com/v1/models"
UserRole = "user"
gptPrefix = "gpt"
)
@@ -33,6 +34,7 @@ type Client struct {
caller http.Caller
capacity int
historyStore history.HistoryStore
serviceURL string
}
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
@@ -40,6 +42,7 @@ func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Cl
caller: caller,
historyStore: hs,
capacity: MaxTokenSize,
serviceURL: DefaultServiceURL,
}
// do not error out when the config cannot be read
@@ -62,6 +65,11 @@ func (c *Client) WithModel(model string) *Client {
return c
}
// WithServiceURL overrides the base service URL used to build API endpoints
// (defaults to DefaultServiceURL). It returns the receiver to allow chaining
// with the other With* builder methods.
// NOTE(review): the URL is concatenated with paths like CompletionPath in
// getEndpoint, so callers presumably pass it without a trailing slash — confirm.
func (c *Client) WithServiceURL(url string) *Client {
	c.serviceURL = url
	return c
}
// ListModels retrieves a list of all available models from the OpenAI API.
// The models are returned as a slice of strings, each entry representing a model ID.
// Models that have an ID starting with 'gpt' are included.
@@ -71,7 +79,7 @@ func (c *Client) WithModel(model string) *Client {
func (c *Client) ListModels() ([]string, error) {
var result []string
raw, err := c.caller.Get(ModelURL)
raw, err := c.caller.Get(c.getEndpoint(ModelPath))
if err != nil {
return nil, err
}
@@ -121,7 +129,7 @@ func (c *Client) Query(input string) (string, error) {
return "", err
}
raw, err := c.caller.Post(CompletionURL, body, false)
raw, err := c.caller.Post(c.getEndpoint(CompletionPath), body, false)
if err != nil {
return "", err
}
@@ -153,7 +161,7 @@ func (c *Client) Stream(input string) error {
return err
}
result, err := c.caller.Post(CompletionURL, body, true)
result, err := c.caller.Post(c.getEndpoint(CompletionPath), body, true)
if err != nil {
return err
}
@@ -197,6 +205,10 @@ func (c *Client) addQuery(query string) {
c.truncateHistory()
}
// getEndpoint builds a full endpoint URL by appending the given path
// (e.g. CompletionPath or ModelPath) to the client's configured serviceURL.
func (c *Client) getEndpoint(path string) string {
	return c.serviceURL + path
}
func (c *Client) prepareQuery(input string) {
c.initHistory()
c.addQuery(input)

View File

@@ -113,7 +113,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := tt.setupPostReturn()
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError)
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, false).Return(respBytes, tt.postError)
_, err = subject.Query(query)
Expect(err).To(HaveOccurred())
@@ -143,7 +143,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
respBytes, err := json.Marshal(response)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, false).Return(respBytes, nil)
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, false).Return(respBytes, nil)
var request types.CompletionsRequest
err = json.Unmarshal(expectedBody, &request)
@@ -283,7 +283,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
Expect(err).NotTo(HaveOccurred())
errorMsg := "error message"
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, true).Return(nil, errors.New(errorMsg))
err := subject.Stream(query)
Expect(err).To(HaveOccurred())
@@ -297,7 +297,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
body, err = createBody(messages, subject.Model, true)
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, true).Return([]byte(answer), nil)
messages = createMessages(history, query)
@@ -351,7 +351,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
errorMsg := "error message"
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, errors.New(errorMsg))
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -360,7 +360,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
it("throws an error when the response is empty", func() {
subject := factory.buildClientWithoutConfig()
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -370,7 +370,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
subject := factory.buildClientWithoutConfig()
malformed := `{"invalid":"json"` // missing closing brace
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return([]byte(malformed), nil)
_, err := subject.ListModels()
Expect(err).To(HaveOccurred())
@@ -382,7 +382,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
response, err := utils.FileToBytes("models.json")
Expect(err).NotTo(HaveOccurred())
mockCaller.EXPECT().Get(client.ModelURL).Return(response, nil)
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(response, nil)
result, err := subject.ListModels()
Expect(err).NotTo(HaveOccurred())

View File

@@ -12,7 +12,7 @@ import (
"github.com/kardolus/chatgpt-cli/utils"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"io/ioutil"
"io"
"os"
"strings"
)
@@ -68,7 +68,7 @@ func run(cmd *cobra.Command, args []string) error {
// Check if there is input from the pipe (stdin)
stat, _ := os.Stdin.Stat()
if (stat.Mode() & os.ModeCharDevice) == 0 {
pipeContent, err := ioutil.ReadAll(os.Stdin)
pipeContent, err := io.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("failed to read from pipe: %w", err)
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/kardolus/chatgpt-cli/types"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
@@ -112,7 +111,7 @@ func (r *RestCaller) doRequest(method, url string, body []byte, stream bool) ([]
return ProcessResponse(response.Body, os.Stdout), nil
}
result, err := ioutil.ReadAll(response.Body)
result, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf(errFailedToRead, err)
}

View File

@@ -43,7 +43,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
bytes, err := json.Marshal(body)
Expect(err).NotTo(HaveOccurred())
resp, err := restCaller.Post(client.CompletionURL, bytes, false)
resp, err := restCaller.Post(client.DefaultServiceURL+client.CompletionPath, bytes, false)
Expect(err).NotTo(HaveOccurred())
var data types.CompletionsResponse
@@ -61,7 +61,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
when("accessing the models endpoint", func() {
it("should have the expected keys in the response", func() {
resp, err := restCaller.Get(client.ModelURL)
resp, err := restCaller.Get(client.DefaultServiceURL + client.ModelPath)
Expect(err).NotTo(HaveOccurred())
var data types.ListModelsResponse