diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index cded7f3..19f6170 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -108,7 +108,7 @@ func write401(w http.ResponseWriter, wwwAuthenticateHeader, errorType, message s // - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview. // // see TestAuthorizationOidcTokenExchange -func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler { +func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier, httpClient *http.Client) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) { @@ -159,7 +159,11 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi if err == nil && sts.IsEnabled() { var exchangedToken *oauth2.Token // If the token is valid, we can exchange it for a new token with the specified audience and scopes. - exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{ + ctx := r.Context() + if httpClient != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + } + exchangedToken, err = sts.ExternalAccountTokenExchange(ctx, &oauth2.Token{ AccessToken: claims.Token, TokenType: "Bearer", }) diff --git a/pkg/http/http.go b/pkg/http/http.go index 3f74c09..8001462 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -24,11 +24,11 @@ const ( sseMessageEndpoint = "/message" ) -func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) error { +func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, httpClient *http.Client) error { mux := http.NewServeMux() wrappedMux := RequestMiddleware( - AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer)(mux), + AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer, httpClient)(mux), ) httpServer := &http.Server{ @@ -44,7 +44,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat mux.HandleFunc(healthEndpoint, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - mux.Handle("/.well-known/", WellKnownHandler(staticConfig)) + mux.Handle("/.well-known/", WellKnownHandler(staticConfig, httpClient)) ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 36e7f88..cac2331 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -89,7 +89,7 @@ func (c *httpContext) beforeEach(t *testing.T) { timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second) group, gc := errgroup.WithContext(timeoutCtx) cancelCtx, c.StopServer = context.WithCancel(gc) - group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) }) + group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider, nil) }) c.WaitForShutdown = group.Wait // Wait for HTTP server to start (using net) for i := 0; i < 10; i++ { diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index 0d80221..6c065fa 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -25,19 +25,24 @@ type WellKnown struct { authorizationUrl string scopesSupported []string disableDynamicClientRegistration bool + httpClient *http.Client } var _ http.Handler = &WellKnown{} -func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler { +func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler { authorizationUrl := staticConfig.AuthorizationURL if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") { authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") } + if httpClient == nil { + httpClient = http.DefaultClient + } return &WellKnown{ authorizationUrl: authorizationUrl, disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration, scopesSupported: staticConfig.OAuthScopes, + httpClient: httpClient, } } @@ -51,7 +56,7 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError) return } - resp, err := http.DefaultClient.Do(req.WithContext(request.Context())) + resp, err := w.httpClient.Do(req.WithContext(request.Context())) if err != nil { http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError) return diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index 1e91d0c..db1782a 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -301,10 +301,11 @@ func (m *MCPServerOptions) Run() error { } var oidcProvider *oidc.Provider + var httpClient *http.Client if m.StaticConfig.AuthorizationURL != "" { ctx := context.Background() if m.StaticConfig.CertificateAuthority != "" { - httpClient := &http.Client{} + httpClient = &http.Client{} caCert, err := os.ReadFile(m.StaticConfig.CertificateAuthority) if err != nil { return fmt.Errorf("failed to read CA certificate from %s: %w", m.StaticConfig.CertificateAuthority, err) @@ -341,7 +342,7 @@ func (m *MCPServerOptions) Run() error { if m.StaticConfig.Port != "" { ctx := context.Background() - return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider) + return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider, httpClient) } if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {