refactor(auth): consolidate JWT validation into single method (#238)

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-08-06 13:17:44 +03:00
committed by GitHub
parent 4302a438ab
commit 4dcede178b
3 changed files with 24 additions and 38 deletions

View File

@@ -56,7 +56,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
// rejected already.
claims, err := ParseJWTClaims(token)
if err == nil && claims != nil {
err = claims.Validate(audience)
err = claims.Validate(r.Context(), audience, oidcProvider)
}
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
@@ -70,21 +70,6 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
return
}
if oidcProvider != nil {
// If OIDC Provider is configured, this token must be validated against it.
if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil {
klog.V(1).Infof("Authentication failed - OIDC token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
} else {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
}
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
}
// Scopes are likely to be used for authorization.
scopes := claims.GetScopes()
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
@@ -138,6 +123,7 @@ var allSignatureAlgorithms = []jose.SignatureAlgorithm{
type JWTClaims struct {
jwt.Claims
Token string `json:"-"`
Scope string `json:"scope,omitempty"`
}
@@ -149,10 +135,21 @@ func (c *JWTClaims) GetScopes() []string {
}
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
func (c *JWTClaims) Validate(audience string) error {
return c.Claims.Validate(jwt.Expected{
AnyAudience: jwt.Audience{audience},
})
func (c *JWTClaims) Validate(ctx context.Context, audience string, provider *oidc.Provider) error {
if err := c.Claims.Validate(jwt.Expected{AnyAudience: jwt.Audience{audience}}); err != nil {
return fmt.Errorf("JWT token validation error: %v", err)
}
if provider != nil {
verifier := provider.Verifier(&oidc.Config{
ClientID: audience,
})
_, err := verifier.Verify(ctx, c.Token)
if err != nil {
return fmt.Errorf("OIDC token validation error: %v", err)
}
}
return nil
}
func ParseJWTClaims(token string) (*JWTClaims, error) {
@@ -162,18 +159,6 @@ func ParseJWTClaims(token string) (*JWTClaims, error) {
}
claims := &JWTClaims{}
err = tkn.UnsafeClaimsWithoutVerification(claims)
claims.Token = token
return claims, err
}
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {
verifier := provider.Verifier(&oidc.Config{
ClientID: audience,
})
_, err := verifier.Verify(ctx, token)
if err != nil {
return fmt.Errorf("JWT token verification failed: %v", err)
}
return nil
}

View File

@@ -111,7 +111,7 @@ func TestJWTTokenValidate(t *testing.T) {
t.Fatalf("expected no error for expired token parsing, got %v", err)
}
err = claims.Validate("kubernetes-mcp-server")
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
if err == nil {
t.Fatalf("expected error for expired token, got nil")
}
@@ -130,7 +130,7 @@ func TestJWTTokenValidate(t *testing.T) {
t.Fatalf("expected claims to be returned, got nil")
}
err = claims.Validate("kubernetes-mcp-server")
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
if err != nil {
t.Fatalf("expected no error for valid audience, got %v", err)
}
@@ -145,7 +145,7 @@ func TestJWTTokenValidate(t *testing.T) {
t.Fatalf("expected claims to be returned, got nil")
}
err = claims.Validate("missing-audience")
err = claims.Validate(t.Context(), "missing-audience", nil)
if err == nil {
t.Fatalf("expected error for token with wrong audience, got nil")
}

View File

@@ -127,6 +127,7 @@ func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *
t.Fatalf("failed to generate private key for oidc: %v", err)
}
oidcServer := &oidctest.Server{
Algorithms: []string{oidc.RS256, oidc.ES256},
PublicKeys: []oidctest.PublicKey{
{
PublicKey: privateKey.Public(),
@@ -470,8 +471,8 @@ func TestAuthorizationUnauthorized(t *testing.T) {
}
})
t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") &&
!strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
!strings.Contains(ctx.LogBuffer.String(), "OIDC token validation error: failed to verify signature") {
t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String())
}
})