mirror of
https://github.com/kardolus/chatgpt-cli.git
synced 2024-09-08 23:15:00 +03:00
Add method to overwrite the service URL
This commit is contained in:
@@ -18,11 +18,12 @@ const (
|
||||
AssistantRole = "assistant"
|
||||
ErrEmptyResponse = "empty response"
|
||||
DefaultGPTModel = "gpt-3.5-turbo"
|
||||
DefaultServiceURL = "https://api.openai.com"
|
||||
CompletionPath = "/v1/chat/completions"
|
||||
ModelPath = "/v1/models"
|
||||
MaxTokenBufferPercentage = 20
|
||||
MaxTokenSize = 4096
|
||||
SystemRole = "system"
|
||||
CompletionURL = "https://api.openai.com/v1/chat/completions"
|
||||
ModelURL = "https://api.openai.com/v1/models"
|
||||
UserRole = "user"
|
||||
gptPrefix = "gpt"
|
||||
)
|
||||
@@ -33,6 +34,7 @@ type Client struct {
|
||||
caller http.Caller
|
||||
capacity int
|
||||
historyStore history.HistoryStore
|
||||
serviceURL string
|
||||
}
|
||||
|
||||
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Client {
|
||||
@@ -40,6 +42,7 @@ func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) *Cl
|
||||
caller: caller,
|
||||
historyStore: hs,
|
||||
capacity: MaxTokenSize,
|
||||
serviceURL: DefaultServiceURL,
|
||||
}
|
||||
|
||||
// do not error out when the config cannot be read
|
||||
@@ -62,6 +65,11 @@ func (c *Client) WithModel(model string) *Client {
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) WithServiceURL(url string) *Client {
|
||||
c.serviceURL = url
|
||||
return c
|
||||
}
|
||||
|
||||
// ListModels retrieves a list of all available models from the OpenAI API.
|
||||
// The models are returned as a slice of strings, each entry representing a model ID.
|
||||
// Models that have an ID starting with 'gpt' are included.
|
||||
@@ -71,7 +79,7 @@ func (c *Client) WithModel(model string) *Client {
|
||||
func (c *Client) ListModels() ([]string, error) {
|
||||
var result []string
|
||||
|
||||
raw, err := c.caller.Get(ModelURL)
|
||||
raw, err := c.caller.Get(c.getEndpoint(ModelPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -121,7 +129,7 @@ func (c *Client) Query(input string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
raw, err := c.caller.Post(CompletionURL, body, false)
|
||||
raw, err := c.caller.Post(c.getEndpoint(CompletionPath), body, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -153,7 +161,7 @@ func (c *Client) Stream(input string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := c.caller.Post(CompletionURL, body, true)
|
||||
result, err := c.caller.Post(c.getEndpoint(CompletionPath), body, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -197,6 +205,10 @@ func (c *Client) addQuery(query string) {
|
||||
c.truncateHistory()
|
||||
}
|
||||
|
||||
func (c *Client) getEndpoint(path string) string {
|
||||
return c.serviceURL + path
|
||||
}
|
||||
|
||||
func (c *Client) prepareQuery(input string) {
|
||||
c.initHistory()
|
||||
c.addQuery(input)
|
||||
|
||||
@@ -113,7 +113,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
respBytes, err := tt.setupPostReturn()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, false).Return(respBytes, tt.postError)
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, false).Return(respBytes, tt.postError)
|
||||
|
||||
_, err = subject.Query(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -143,7 +143,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
|
||||
respBytes, err := json.Marshal(response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, false).Return(respBytes, nil)
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, false).Return(respBytes, nil)
|
||||
|
||||
var request types.CompletionsRequest
|
||||
err = json.Unmarshal(expectedBody, &request)
|
||||
@@ -283,7 +283,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, body, true).Return(nil, errors.New(errorMsg))
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, body, true).Return(nil, errors.New(errorMsg))
|
||||
|
||||
err := subject.Stream(query)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -297,7 +297,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
body, err = createBody(messages, subject.Model, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Post(client.CompletionURL, expectedBody, true).Return([]byte(answer), nil)
|
||||
mockCaller.EXPECT().Post(client.DefaultServiceURL+client.CompletionPath, expectedBody, true).Return([]byte(answer), nil)
|
||||
|
||||
messages = createMessages(history, query)
|
||||
|
||||
@@ -351,7 +351,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
errorMsg := "error message"
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, errors.New(errorMsg))
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, errors.New(errorMsg))
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -360,7 +360,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
it("throws an error when the response is empty", func() {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return(nil, nil)
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(nil, nil)
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -370,7 +370,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
subject := factory.buildClientWithoutConfig()
|
||||
|
||||
malformed := `{"invalid":"json"` // missing closing brace
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return([]byte(malformed), nil)
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return([]byte(malformed), nil)
|
||||
|
||||
_, err := subject.ListModels()
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -382,7 +382,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
|
||||
response, err := utils.FileToBytes("models.json")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
mockCaller.EXPECT().Get(client.ModelURL).Return(response, nil)
|
||||
mockCaller.EXPECT().Get(client.DefaultServiceURL+client.ModelPath).Return(response, nil)
|
||||
|
||||
result, err := subject.ListModels()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Reference in New Issue
Block a user