feat(auth): configurable audience validation (#251)

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-08-08 08:50:50 +03:00
committed by GitHub
parent b0da9fb459
commit 7b11c1667a
6 changed files with 117 additions and 54 deletions

View File

@@ -19,13 +19,17 @@ type StaticConfig struct {
// When true, expose only tools annotated with readOnlyHint=true
ReadOnly bool `toml:"read_only,omitempty"`
// When true, disable tools annotated with destructiveHint=true
DisableDestructive bool `toml:"disable_destructive,omitempty"`
EnabledTools []string `toml:"enabled_tools,omitempty"`
DisabledTools []string `toml:"disabled_tools,omitempty"`
RequireOAuth bool `toml:"require_oauth,omitempty"`
AuthorizationURL string `toml:"authorization_url,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
DisableDestructive bool `toml:"disable_destructive,omitempty"`
EnabledTools []string `toml:"enabled_tools,omitempty"`
DisabledTools []string `toml:"disabled_tools,omitempty"`
RequireOAuth bool `toml:"require_oauth,omitempty"`
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
OAuthAudience string `toml:"oauth_audience,omitempty"`
// AuthorizationURL is the URL of the OIDC authorization server.
// It is used for token validation and for STS token exchange.
AuthorizationURL string `toml:"authorization_url,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
}
type GroupVersionKind struct {

View File

@@ -41,8 +41,9 @@ type KubernetesApiTokenVerifier interface {
// 2. requireOAuth is set to true, server is protected:
//
// 2.1. Raw Token Validation (oidcProvider is nil):
// - The token is validated offline for basic sanity checks (audience and expiration).
// - The token is then used against the Kubernetes API Server for TokenReview.
// - The token is validated offline for basic sanity checks (expiration).
// - If audience is set, the token is validated against the audience.
// - The token is then used against the Kubernetes API Server for TokenReview.
//
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
// - The token is validated offline for basic sanity checks (audience and expiration).
@@ -50,7 +51,7 @@ type KubernetesApiTokenVerifier interface {
// - The token is then used against the Kubernetes API Server for TokenReview.
//
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
func AuthorizationMiddleware(requireOAuth bool, audience string, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
@@ -62,13 +63,16 @@ func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, ver
return
}
audience := Audience
wwwAuthenticateHeader := "Bearer realm=\"Kubernetes MCP Server\""
if audience != "" {
wwwAuthenticateHeader += fmt.Sprintf(`, audience="%s"`, audience)
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader+", error=\"missing_token\"")
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
return
}
@@ -77,12 +81,15 @@ func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, ver
claims, err := ParseJWTClaims(token)
if err == nil && claims != nil {
err = claims.Validate(r.Context(), audience, oidcProvider)
err = claims.ValidateOffline(audience)
}
if err == nil && claims != nil {
err = claims.ValidateWithProvider(r.Context(), audience, oidcProvider)
}
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader+", error=\"invalid_token\"")
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
@@ -147,16 +154,24 @@ func (c *JWTClaims) GetScopes() []string {
return strings.Fields(c.Scope)
}
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
func (c *JWTClaims) Validate(ctx context.Context, audience string, provider *oidc.Provider) error {
if err := c.Claims.Validate(jwt.Expected{AnyAudience: jwt.Audience{audience}}); err != nil {
// ValidateOffline Checks if the JWT claims are valid and if the audience matches the expected one.
func (c *JWTClaims) ValidateOffline(audience string) error {
expected := jwt.Expected{}
if audience != "" {
expected.AnyAudience = jwt.Audience{audience}
}
if err := c.Validate(expected); err != nil {
return fmt.Errorf("JWT token validation error: %v", err)
}
return nil
}
// ValidateWithProvider validates the JWT claims against the OIDC provider.
func (c *JWTClaims) ValidateWithProvider(ctx context.Context, audience string, provider *oidc.Provider) error {
if provider != nil {
verifier := provider.Verifier(&oidc.Config{
ClientID: audience,
})
_, err := verifier.Verify(ctx, c.Token)
if err != nil {
return fmt.Errorf("OIDC token validation error: %v", err)

View File

@@ -104,14 +104,14 @@ func TestParseJWTClaimsPayloadInvalid(t *testing.T) {
})
}
func TestJWTTokenValidate(t *testing.T) {
func TestJWTTokenValidateOffline(t *testing.T) {
t.Run("expired token returns error", func(t *testing.T) {
claims, err := ParseJWTClaims(tokenBasicExpired)
if err != nil {
t.Fatalf("expected no error for expired token parsing, got %v", err)
}
err = claims.Validate(t.Context(), "mcp-server", nil)
err = claims.ValidateOffline("mcp-server")
if err == nil {
t.Fatalf("expected error for expired token, got nil")
}
@@ -130,7 +130,7 @@ func TestJWTTokenValidate(t *testing.T) {
t.Fatalf("expected claims to be returned, got nil")
}
err = claims.Validate(t.Context(), "mcp-server", nil)
err = claims.ValidateOffline("mcp-server")
if err != nil {
t.Fatalf("expected no error for valid audience, got %v", err)
}
@@ -145,7 +145,7 @@ func TestJWTTokenValidate(t *testing.T) {
t.Fatalf("expected claims to be returned, got nil")
}
err = claims.Validate(t.Context(), "missing-audience", nil)
err = claims.ValidateOffline("missing-audience")
if err == nil {
t.Fatalf("expected error for token with wrong audience, got nil")
}

View File

@@ -28,7 +28,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
mux := http.NewServeMux()
wrappedMux := RequestMiddleware(
AuthorizationMiddleware(staticConfig.RequireOAuth, oidcProvider, mcpServer)(mux),
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.OAuthAudience, oidcProvider, mcpServer)(mux),
)
httpServer := &http.Server{

View File

@@ -390,7 +390,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="missing_token"`
expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
@@ -415,7 +415,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="missing_token"`
expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
@@ -432,7 +432,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer invalid_base64"+tokenBasicNotExpired)
req.Header.Set("Authorization", "Bearer "+strings.ReplaceAll(tokenBasicNotExpired, ".", ".invalid"))
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
@@ -445,13 +445,13 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"`
expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
!strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") {
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
}
@@ -476,22 +476,53 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"`
expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
!strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") {
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
}
})
})
// Invalid audience claim Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience"}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+tokenBasicExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INVALID AUDIENCE Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with INVALID AUDIENCE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="expected-audience", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INVALID AUDIENCE Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
!strings.Contains(ctx.LogBuffer.String(), "invalid audience claim (aud)") {
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
}
})
})
// Failed OIDC validation
key, oidcProvider, httpServer := NewOidcTestServer(t)
t.Cleanup(httpServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -528,7 +559,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -576,30 +607,37 @@ func TestAuthorizationRequireOAuthFalse(t *testing.T) {
}
func TestAuthorizationRawToken(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tokenReviewSuccessful))
return
cases := []string{
"",
"mcp-server",
}
for _, audience := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: audience}}, func(ctx *httpContext) {
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tokenReviewSuccessful))
return
}
}))
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
}))
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Protected resource with VALID Authorization header returns 200 - OK", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Protected resource with audience = '"+audience+"' with VALID Authorization header returns 200 - OK", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
})
})
}
}
func TestAuthorizationOidcToken(t *testing.T) {
@@ -611,7 +649,7 @@ func TestAuthorizationOidcToken(t *testing.T) {
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")

View File

@@ -62,6 +62,7 @@ type MCPServerOptions struct {
ReadOnly bool
DisableDestructive bool
RequireOAuth bool
OAuthAudience string
AuthorizationURL string
CertificateAuthority string
ServerURL string
@@ -119,6 +120,8 @@ func NewMCPServer(streams genericiooptions.IOStreams) *cobra.Command {
cmd.Flags().BoolVar(&o.DisableDestructive, "disable-destructive", o.DisableDestructive, "If true, tools annotated with destructiveHint=true are disabled")
cmd.Flags().BoolVar(&o.RequireOAuth, "require-oauth", o.RequireOAuth, "If true, requires OAuth authorization as defined in the Model Context Protocol (MCP) specification. This flag is ignored if transport type is stdio")
_ = cmd.Flags().MarkHidden("require-oauth")
cmd.Flags().StringVar(&o.OAuthAudience, "oauth-audience", o.OAuthAudience, "OAuth audience for token claims validation. Optional. If not set, the audience is not validated. Only valid if require-oauth is enabled.")
_ = cmd.Flags().MarkHidden("oauth-audience")
cmd.Flags().StringVar(&o.AuthorizationURL, "authorization-url", o.AuthorizationURL, "OAuth authorization server URL for protected resource endpoint. If not provided, the Kubernetes API server host will be used. Only valid if require-oauth is enabled.")
_ = cmd.Flags().MarkHidden("authorization-url")
cmd.Flags().StringVar(&o.ServerURL, "server-url", o.ServerURL, "Server URL of this application. Optional. If set, this url will be served in protected resource metadata endpoint and tokens will be validated with this audience. If not set, expected audience is kubernetes-mcp-server. Only valid if require-oauth is enabled.")
@@ -179,6 +182,9 @@ func (m *MCPServerOptions) loadFlags(cmd *cobra.Command) {
if cmd.Flag("require-oauth").Changed {
m.StaticConfig.RequireOAuth = m.RequireOAuth
}
if cmd.Flag("oauth-audience").Changed {
m.StaticConfig.OAuthAudience = m.OAuthAudience
}
if cmd.Flag("authorization-url").Changed {
m.StaticConfig.AuthorizationURL = m.AuthorizationURL
}