feat(http): add custom CA certificate support for OIDC providers

This commit is contained in:
Matthias Wessendorf
2025-10-20 17:16:52 +02:00
committed by GitHub
parent 7f4edfd075
commit 49afbad502
5 changed files with 20 additions and 10 deletions

View File

@@ -108,7 +108,7 @@ func write401(w http.ResponseWriter, wwwAuthenticateHeader, errorType, message s
// - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview. // - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview.
// //
// see TestAuthorizationOidcTokenExchange // see TestAuthorizationOidcTokenExchange
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler { func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier, httpClient *http.Client) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) { if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
@@ -159,7 +159,11 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
if err == nil && sts.IsEnabled() { if err == nil && sts.IsEnabled() {
var exchangedToken *oauth2.Token var exchangedToken *oauth2.Token
// If the token is valid, we can exchange it for a new token with the specified audience and scopes. // If the token is valid, we can exchange it for a new token with the specified audience and scopes.
exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{ ctx := r.Context()
if httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
exchangedToken, err = sts.ExternalAccountTokenExchange(ctx, &oauth2.Token{
AccessToken: claims.Token, AccessToken: claims.Token,
TokenType: "Bearer", TokenType: "Bearer",
}) })

View File

@@ -24,11 +24,11 @@ const (
sseMessageEndpoint = "/message" sseMessageEndpoint = "/message"
) )
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) error { func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, httpClient *http.Client) error {
mux := http.NewServeMux() mux := http.NewServeMux()
wrappedMux := RequestMiddleware( wrappedMux := RequestMiddleware(
AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer)(mux), AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer, httpClient)(mux),
) )
httpServer := &http.Server{ httpServer := &http.Server{
@@ -44,7 +44,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
mux.HandleFunc(healthEndpoint, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(healthEndpoint, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
mux.Handle("/.well-known/", WellKnownHandler(staticConfig)) mux.Handle("/.well-known/", WellKnownHandler(staticConfig, httpClient))
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()

View File

@@ -89,7 +89,7 @@ func (c *httpContext) beforeEach(t *testing.T) {
timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second) timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second)
group, gc := errgroup.WithContext(timeoutCtx) group, gc := errgroup.WithContext(timeoutCtx)
cancelCtx, c.StopServer = context.WithCancel(gc) cancelCtx, c.StopServer = context.WithCancel(gc)
group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) }) group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider, nil) })
c.WaitForShutdown = group.Wait c.WaitForShutdown = group.Wait
// Wait for HTTP server to start (using net) // Wait for HTTP server to start (using net)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {

View File

@@ -25,19 +25,24 @@ type WellKnown struct {
authorizationUrl string authorizationUrl string
scopesSupported []string scopesSupported []string
disableDynamicClientRegistration bool disableDynamicClientRegistration bool
httpClient *http.Client
} }
var _ http.Handler = &WellKnown{} var _ http.Handler = &WellKnown{}
func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler { func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler {
authorizationUrl := staticConfig.AuthorizationURL authorizationUrl := staticConfig.AuthorizationURL
if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") { if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") {
authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") authorizationUrl = strings.TrimSuffix(authorizationUrl, "/")
} }
if httpClient == nil {
httpClient = http.DefaultClient
}
return &WellKnown{ return &WellKnown{
authorizationUrl: authorizationUrl, authorizationUrl: authorizationUrl,
disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration, disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration,
scopesSupported: staticConfig.OAuthScopes, scopesSupported: staticConfig.OAuthScopes,
httpClient: httpClient,
} }
} }
@@ -51,7 +56,7 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request)
http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError) http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError)
return return
} }
resp, err := http.DefaultClient.Do(req.WithContext(request.Context())) resp, err := w.httpClient.Do(req.WithContext(request.Context()))
if err != nil { if err != nil {
http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError) http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError)
return return

View File

@@ -301,10 +301,11 @@ func (m *MCPServerOptions) Run() error {
} }
var oidcProvider *oidc.Provider var oidcProvider *oidc.Provider
var httpClient *http.Client
if m.StaticConfig.AuthorizationURL != "" { if m.StaticConfig.AuthorizationURL != "" {
ctx := context.Background() ctx := context.Background()
if m.StaticConfig.CertificateAuthority != "" { if m.StaticConfig.CertificateAuthority != "" {
httpClient := &http.Client{} httpClient = &http.Client{}
caCert, err := os.ReadFile(m.StaticConfig.CertificateAuthority) caCert, err := os.ReadFile(m.StaticConfig.CertificateAuthority)
if err != nil { if err != nil {
return fmt.Errorf("failed to read CA certificate from %s: %w", m.StaticConfig.CertificateAuthority, err) return fmt.Errorf("failed to read CA certificate from %s: %w", m.StaticConfig.CertificateAuthority, err)
@@ -341,7 +342,7 @@ func (m *MCPServerOptions) Run() error {
if m.StaticConfig.Port != "" { if m.StaticConfig.Port != "" {
ctx := context.Background() ctx := context.Background()
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider) return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider, httpClient)
} }
if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) { if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {