Refactor HTTP client initialization with factory pattern

This commit is contained in:
kardolus
2023-11-10 12:01:38 -05:00
parent bd485025d0
commit 2e0b5445b6
7 changed files with 40 additions and 69 deletions

View File

@@ -62,15 +62,3 @@ func (mr *MockCallerMockRecorder) Post(arg0, arg1, arg2 interface{}) *gomock.Cal
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockCaller)(nil).Post), arg0, arg1, arg2)
}
// SetAPIKey mocks base method.
func (m *MockCaller) SetAPIKey(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAPIKey", arg0)
}
// SetAPIKey indicates an expected call of SetAPIKey.
func (mr *MockCallerMockRecorder) SetAPIKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAPIKey", reflect.TypeOf((*MockCaller)(nil).SetAPIKey), arg0)
}

View File

@@ -29,7 +29,7 @@ type Client struct {
historyStore history.HistoryStore
}
func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
func New(callerFactory http.CallerFactory, cs config.ConfigStore, hs history.HistoryStore) (*Client, error) {
cm := configmanager.New(cs).WithEnvironment()
configuration := cm.Config
@@ -37,7 +37,8 @@ func New(caller http.Caller, cs config.ConfigStore, hs history.HistoryStore) (*C
return nil, errors.New("missing environment variable: " + cm.APIKeyEnvVarName())
}
caller.SetAPIKey(configuration.APIKey)
caller := callerFactory(configuration)
hs.SetThread(configuration.Thread)
return &Client{

View File

@@ -6,6 +6,7 @@ import (
"github.com/golang/mock/gomock"
_ "github.com/golang/mock/mockgen/model"
"github.com/kardolus/chatgpt-cli/client"
"github.com/kardolus/chatgpt-cli/http"
"github.com/kardolus/chatgpt-cli/types"
"github.com/kardolus/chatgpt-cli/utils"
"os"
@@ -60,7 +61,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockHistoryStore = NewMockHistoryStore(mockCtrl)
mockConfigStore = NewMockConfigStore(mockCtrl)
factory = newClientFactory(mockCaller, mockConfigStore, mockHistoryStore)
factory = newClientFactory(mockConfigStore, mockHistoryStore)
apiKeyEnvVar = strings.ToUpper(defaultName) + "_API_KEY"
Expect(os.Setenv(apiKeyEnvVar, envApiKey)).To(Succeed())
@@ -76,7 +77,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
_, err := client.New(mockCaller, mockConfigStore, mockHistoryStore)
_, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(apiKeyEnvVar))
@@ -244,11 +245,10 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
testValidHTTPResponse(subject, history, body, false)
})
it("ignores history when configured to do so", func() {
mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{OmitHistory: true}, nil).Times(1)
subject, err := client.New(mockCaller, mockConfigStore, mockHistoryStore)
subject, err := client.New(mockCallerFactory, mockConfigStore, mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
// Read and Write are never called on the history store
@@ -505,12 +505,11 @@ func createMessages(history []types.Message, query string) []types.Message {
}
type clientFactory struct {
mockCaller *MockCaller
mockConfigStore *MockConfigStore
mockHistoryStore *MockHistoryStore
}
func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
func newClientFactory(mcs *MockConfigStore, mhs *MockHistoryStore) *clientFactory {
mockConfigStore.EXPECT().ReadDefaults().Return(types.Config{
Name: defaultName,
Model: defaultModel,
@@ -527,29 +526,26 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
}).Times(1)
return &clientFactory{
mockCaller: mc,
mockConfigStore: mcs,
mockHistoryStore: mhs,
}
}
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
f.mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
f.mockConfigStore.EXPECT().Read().Return(types.Config{}, nil).Times(1)
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50)
}
func (f *clientFactory) buildClientWithConfig(config types.Config) *client.Client {
f.mockCaller.EXPECT().SetAPIKey(envApiKey).Times(1)
f.mockHistoryStore.EXPECT().SetThread(defaultThread).Times(1)
f.mockConfigStore.EXPECT().Read().Return(config, nil).Times(1)
c, err := client.New(f.mockCaller, f.mockConfigStore, f.mockHistoryStore)
c, err := client.New(mockCallerFactory, f.mockConfigStore, f.mockHistoryStore)
Expect(err).NotTo(HaveOccurred())
return c.WithCapacity(50)
@@ -562,3 +558,7 @@ func (f *clientFactory) withoutHistory() {
func (f *clientFactory) withHistory(history []types.Message) {
f.mockHistoryStore.EXPECT().Read().Return(history, nil).Times(1)
}
func mockCallerFactory(cfg types.Config) http.Caller {
return mockCaller
}

View File

@@ -116,7 +116,7 @@ func run(cmd *cobra.Command, args []string) error {
}
hs, _ := history.New() // do not error out
client, err := client.New(http.New(), config.New(), hs)
client, err := client.New(http.RealCallerFactory, config.New(), hs)
if err != nil {
return err
}

View File

@@ -27,25 +27,27 @@ const (
type Caller interface {
Post(url string, body []byte, stream bool) ([]byte, error)
Get(url string) ([]byte, error)
SetAPIKey(secret string)
}
type RestCaller struct {
client *http.Client
secret string
config types.Config
}
// Ensure RestCaller implements Caller interface
var _ Caller = &RestCaller{}
func New() *RestCaller {
func New(cfg types.Config) *RestCaller {
return &RestCaller{
client: &http.Client{},
config: cfg,
}
}
func (r *RestCaller) SetAPIKey(secret string) {
r.secret = secret
type CallerFactory func(cfg types.Config) Caller
func RealCallerFactory(cfg types.Config) Caller {
return New(cfg)
}
func (r *RestCaller) Get(url string) ([]byte, error) {
@@ -136,8 +138,8 @@ func (r *RestCaller) newRequest(method, url string, body []byte) (*http.Request,
return nil, err
}
if r.secret != "" {
req.Header.Set(headerAuthorization, fmt.Sprintf(bearer, r.secret))
if r.config.APIKey != "" {
req.Header.Set(headerAuthorization, fmt.Sprintf(bearer, r.config.APIKey))
}
req.Header.Set(headerContentType, contentType)

View File

@@ -21,7 +21,7 @@ func TestContract(t *testing.T) {
func testContract(t *testing.T, when spec.G, it spec.S) {
var (
restCaller *http.RestCaller
defaults types.Config
cfg types.Config
)
it.Before(func() {
@@ -30,10 +30,10 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
apiKey := os.Getenv(configmanager.New(config.New()).WithEnvironment().APIKeyEnvVarName())
Expect(apiKey).NotTo(BeEmpty())
restCaller = http.New()
restCaller.SetAPIKey(apiKey)
cfg = config.New().ReadDefaults()
cfg.APIKey = apiKey
defaults = config.New().ReadDefaults()
restCaller = http.New(cfg)
})
when("accessing the completion endpoint", func() {
@@ -41,16 +41,16 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
body := types.CompletionsRequest{
Messages: []types.Message{{
Role: client.SystemRole,
Content: defaults.Role,
Content: cfg.Role,
}},
Model: defaults.Model,
Model: cfg.Model,
Stream: false,
}
bytes, err := json.Marshal(body)
Expect(err).NotTo(HaveOccurred())
resp, err := restCaller.Post(defaults.URL+defaults.CompletionsPath, bytes, false)
resp, err := restCaller.Post(cfg.URL+cfg.CompletionsPath, bytes, false)
Expect(err).NotTo(HaveOccurred())
var data types.CompletionsResponse
@@ -66,22 +66,19 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
})
it("should return an error response with appropriate error details", func() {
// Set the wrong API key
restCaller.SetAPIKey("wrong-key")
body := types.CompletionsRequest{
Messages: []types.Message{{
Role: client.SystemRole,
Content: defaults.Role,
Content: cfg.Role,
}},
Model: defaults.Model,
Model: "no-such-model",
Stream: false,
}
bytes, err := json.Marshal(body)
Expect(err).NotTo(HaveOccurred())
resp, err := restCaller.Post(defaults.URL+defaults.CompletionsPath, bytes, false)
resp, err := restCaller.Post(cfg.URL+cfg.CompletionsPath, bytes, false)
Expect(err).To(HaveOccurred())
var errorData types.ErrorResponse
@@ -96,7 +93,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
when("accessing the models endpoint", func() {
it("should have the expected keys in the response", func() {
resp, err := restCaller.Get(defaults.URL + defaults.ModelsPath)
resp, err := restCaller.Get(cfg.URL + cfg.ModelsPath)
Expect(err).NotTo(HaveOccurred())
var data types.ListModelsResponse
@@ -112,8 +109,6 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
Expect(model.Object).ShouldNot(BeEmpty(), "Expected Model Object to be present in the response")
Expect(model.Created).ShouldNot(BeZero(), "Expected Model Created to be present in the response")
Expect(model.OwnedBy).ShouldNot(BeEmpty(), "Expected Model OwnedBy to be present in the response")
Expect(model.Permission).ShouldNot(BeNil(), "Expected Model Permission to be present in the response")
Expect(model.Root).ShouldNot(BeEmpty(), "Expected Model Root to be present in the response")
}
})
})

View File

@@ -6,24 +6,9 @@ type ListModelsResponse struct {
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group interface{} `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Parent *string `json:"parent"`
}