mirror of
https://github.com/openshift/openshift-mcp-server.git
synced 2025-10-17 14:27:48 +03:00
refactor(auth): consolidate JWT validation into single method (#238)
Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
@@ -56,7 +56,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
// rejected already.
|
||||
claims, err := ParseJWTClaims(token)
|
||||
if err == nil && claims != nil {
|
||||
err = claims.Validate(audience)
|
||||
err = claims.Validate(r.Context(), audience, oidcProvider)
|
||||
}
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
@@ -70,21 +70,6 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
return
|
||||
}
|
||||
|
||||
if oidcProvider != nil {
|
||||
// If OIDC Provider is configured, this token must be validated against it.
|
||||
if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil {
|
||||
klog.V(1).Infof("Authentication failed - OIDC token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes are likely to be used for authorization.
|
||||
scopes := claims.GetScopes()
|
||||
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
|
||||
@@ -138,6 +123,7 @@ var allSignatureAlgorithms = []jose.SignatureAlgorithm{
|
||||
|
||||
type JWTClaims struct {
|
||||
jwt.Claims
|
||||
Token string `json:"-"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
@@ -149,10 +135,21 @@ func (c *JWTClaims) GetScopes() []string {
|
||||
}
|
||||
|
||||
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
|
||||
func (c *JWTClaims) Validate(audience string) error {
|
||||
return c.Claims.Validate(jwt.Expected{
|
||||
AnyAudience: jwt.Audience{audience},
|
||||
})
|
||||
func (c *JWTClaims) Validate(ctx context.Context, audience string, provider *oidc.Provider) error {
|
||||
if err := c.Claims.Validate(jwt.Expected{AnyAudience: jwt.Audience{audience}}); err != nil {
|
||||
return fmt.Errorf("JWT token validation error: %v", err)
|
||||
}
|
||||
if provider != nil {
|
||||
verifier := provider.Verifier(&oidc.Config{
|
||||
ClientID: audience,
|
||||
})
|
||||
|
||||
_, err := verifier.Verify(ctx, c.Token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("OIDC token validation error: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseJWTClaims(token string) (*JWTClaims, error) {
|
||||
@@ -162,18 +159,6 @@ func ParseJWTClaims(token string) (*JWTClaims, error) {
|
||||
}
|
||||
claims := &JWTClaims{}
|
||||
err = tkn.UnsafeClaimsWithoutVerification(claims)
|
||||
claims.Token = token
|
||||
return claims, err
|
||||
}
|
||||
|
||||
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {
|
||||
verifier := provider.Verifier(&oidc.Config{
|
||||
ClientID: audience,
|
||||
})
|
||||
|
||||
_, err := verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JWT token verification failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestJWTTokenValidate(t *testing.T) {
|
||||
t.Fatalf("expected no error for expired token parsing, got %v", err)
|
||||
}
|
||||
|
||||
err = claims.Validate("kubernetes-mcp-server")
|
||||
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for expired token, got nil")
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func TestJWTTokenValidate(t *testing.T) {
|
||||
t.Fatalf("expected claims to be returned, got nil")
|
||||
}
|
||||
|
||||
err = claims.Validate("kubernetes-mcp-server")
|
||||
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for valid audience, got %v", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func TestJWTTokenValidate(t *testing.T) {
|
||||
t.Fatalf("expected claims to be returned, got nil")
|
||||
}
|
||||
|
||||
err = claims.Validate("missing-audience")
|
||||
err = claims.Validate(t.Context(), "missing-audience", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for token with wrong audience, got nil")
|
||||
}
|
||||
|
||||
@@ -127,6 +127,7 @@ func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *
|
||||
t.Fatalf("failed to generate private key for oidc: %v", err)
|
||||
}
|
||||
oidcServer := &oidctest.Server{
|
||||
Algorithms: []string{oidc.RS256, oidc.ES256},
|
||||
PublicKeys: []oidctest.PublicKey{
|
||||
{
|
||||
PublicKey: privateKey.Public(),
|
||||
@@ -470,8 +471,8 @@ func TestAuthorizationUnauthorized(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") &&
|
||||
!strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
|
||||
!strings.Contains(ctx.LogBuffer.String(), "OIDC token validation error: failed to verify signature") {
|
||||
t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user