fix(auth): delegate JWT parsing to github.com/go-jose/go-jose (189)

fix(auth): delegate JWT parsing to github.com/golang-jwt/jwt

Signed-off-by: Marc Nuri <marc@marcnuri.com>
---
fix(auth): delegate JWT parsing to go-jose

Signed-off-by: Marc Nuri <marc@marcnuri.com>
---
fix(auth): delegate JWT parsing to go-jose - review comment

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-07-18 13:01:55 +02:00
committed by GitHub
parent 73e9e845c4
commit 775fa21bd1
3 changed files with 189 additions and 265 deletions

2
go.mod
View File

@@ -6,6 +6,7 @@ require (
github.com/BurntSushi/toml v1.5.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/fsnotify/fsnotify v1.9.0
github.com/go-jose/go-jose/v4 v4.0.5
github.com/mark3labs/mcp-go v0.34.0
github.com/pkg/errors v0.9.1
github.com/spf13/afero v1.14.0
@@ -52,7 +53,6 @@ require (
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-errors/errors v1.4.2 // indirect
github.com/go-gorp/gorp/v3 v3.1.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect

View File

@@ -2,14 +2,13 @@ package http
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"k8s.io/klog/v2"
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
@@ -55,7 +54,10 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
// Validate the token offline for simple sanity check
// Because missing expected audience and expired tokens must be
// rejected already.
claims, err := validateJWTToken(token, audience)
claims, err := ParseJWTClaims(token)
if err == nil && claims != nil {
err = claims.Validate(audience)
}
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
@@ -117,11 +119,25 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
}
}
var allSignatureAlgorithms = []jose.SignatureAlgorithm{
jose.EdDSA,
jose.HS256,
jose.HS384,
jose.HS512,
jose.RS256,
jose.RS384,
jose.RS512,
jose.ES256,
jose.ES384,
jose.ES512,
jose.PS256,
jose.PS384,
jose.PS512,
}
type JWTClaims struct {
Issuer string `json:"iss"`
Audience any `json:"aud"`
ExpiresAt int64 `json:"exp"`
Scope string `json:"scope,omitempty"`
jwt.Claims
Scope string `json:"scope,omitempty"`
}
func (c *JWTClaims) GetScopes() []string {
@@ -131,66 +147,21 @@ func (c *JWTClaims) GetScopes() []string {
return strings.Fields(c.Scope)
}
func (c *JWTClaims) ContainsAudience(audience string) bool {
switch aud := c.Audience.(type) {
case string:
return aud == audience
case []interface{}:
for _, a := range aud {
if str, ok := a.(string); ok && str == audience {
return true
}
}
case []string:
for _, a := range aud {
if a == audience {
return true
}
}
}
return false
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
func (c *JWTClaims) Validate(audience string) error {
return c.Claims.Validate(jwt.Expected{
AnyAudience: jwt.Audience{audience},
})
}
// validateJWTToken validates basic JWT claims without signature verification and returns the claims
func validateJWTToken(token, audience string) (*JWTClaims, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT token format")
}
claims, err := parseJWTClaims(parts[1])
func ParseJWTClaims(token string) (*JWTClaims, error) {
tkn, err := jwt.ParseSigned(token, allSignatureAlgorithms)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %v", err)
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
}
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
return nil, fmt.Errorf("token expired")
}
if !claims.ContainsAudience(audience) {
return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience)
}
return claims, nil
}
func parseJWTClaims(payload string) (*JWTClaims, error) {
// Add padding if needed
if len(payload)%4 != 0 {
payload += strings.Repeat("=", 4-len(payload)%4)
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
}
return &claims, nil
claims := &JWTClaims{}
err = tkn.UnsafeClaimsWithoutVerification(claims)
return claims, err
}
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {

View File

@@ -1,187 +1,222 @@
package http
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4/jwt"
)
func TestParseJWTClaims(t *testing.T) {
t.Run("valid JWT payload", func(t *testing.T) {
// Sample payload from a valid JWT
payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9"
const (
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA // notsecret
tokenBasicNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA" // notsecret
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA // notsecret
tokenBasicExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA" // notsecret
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ // notsecret
tokenMultipleAudienceNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ" // notsecret
)
claims, err := parseJWTClaims(payload)
func TestParseJWTClaimsPayloadValid(t *testing.T) {
basicClaims, err := ParseJWTClaims(tokenBasicNotExpired)
t.Run("Is parseable", func(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if claims == nil {
if basicClaims == nil {
t.Fatal("expected claims, got nil")
}
if claims.Issuer != "https://kubernetes.default.svc.cluster.local" {
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer)
})
t.Run("Parses issuer", func(t *testing.T) {
if basicClaims.Issuer != "https://kubernetes.default.svc.cluster.local" {
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", basicClaims.Issuer)
}
})
t.Run("Parses audience", func(t *testing.T) {
expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"}
for _, expected := range expectedAudiences {
if !claims.ContainsAudience(expected) {
if !basicClaims.Audience.Contains(expected) {
t.Errorf("expected audience to contain %s", expected)
}
}
if claims.ExpiresAt != 1751963948 {
t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt)
})
t.Run("Parses expiration", func(t *testing.T) {
if *basicClaims.Expiry != jwt.NumericDate(253402297199) {
t.Errorf("expected expiration 253402297199, got %d", basicClaims.Expiry)
}
})
t.Run("payload needs padding", func(t *testing.T) {
// Create a payload that needs padding
testClaims := JWTClaims{
Issuer: "test-issuer",
Audience: "test-audience",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
t.Run("Parses scope", func(t *testing.T) {
scopeClaims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if scopeClaims == nil {
t.Fatal("expected claims, got nil")
}
jsonBytes, _ := json.Marshal(testClaims)
// Create a payload without proper padding
encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=")
scopes := scopeClaims.GetScopes()
claims, err := parseJWTClaims(encodedWithoutPadding)
expectedScopes := []string{"read", "write"}
if len(scopes) != len(expectedScopes) {
t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(scopes))
}
for i, expectedScope := range expectedScopes {
if scopes[i] != expectedScope {
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
}
}
})
t.Run("Parses expired token", func(t *testing.T) {
expiredClaims, err := ParseJWTClaims(tokenBasicExpired)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if claims.Issuer != "test-issuer" {
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
}
})
t.Run("invalid base64 payload", func(t *testing.T) {
invalidPayload := "invalid-base64!!!"
_, err := parseJWTClaims(invalidPayload)
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
if !strings.Contains(err.Error(), "failed to decode JWT payload") {
t.Errorf("expected decode error message, got %v", err)
}
})
t.Run("invalid JSON payload", func(t *testing.T) {
// Valid base64 but invalid JSON
invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json"))
_, err := parseJWTClaims(invalidJSON)
if err == nil {
t.Error("expected error for invalid JSON, got nil")
}
if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") {
t.Errorf("expected unmarshal error message, got %v", err)
if *expiredClaims.Expiry != jwt.NumericDate(1) {
t.Errorf("expected expiration 1, got %d", basicClaims.Expiry)
}
})
}
func TestValidateJWTToken(t *testing.T) {
t.Run("invalid token format - not enough parts", func(t *testing.T) {
invalidToken := "header.payload"
func TestParseJWTClaimsPayloadInvalid(t *testing.T) {
t.Run("invalid token segments", func(t *testing.T) {
invalidToken := "header.payload.signature.extra"
_, err := validateJWTToken(invalidToken, "test")
_, err := ParseJWTClaims(invalidToken)
if err == nil {
t.Error("expected error for invalid token format, got nil")
t.Fatal("expected error for invalid token segments, got nil")
}
if !strings.Contains(err.Error(), "invalid JWT token format") {
t.Errorf("expected format error message, got %v", err)
if !strings.Contains(err.Error(), "compact JWS format must have three parts") {
t.Errorf("expected invalid token segments error message, got %v", err)
}
})
t.Run("invalid base64 payload", func(t *testing.T) {
invalidPayload := "invalid_base64" + tokenBasicNotExpired
t.Run("expired token", func(t *testing.T) {
// Create an expired token
expiredClaims := JWTClaims{
Issuer: "test-issuer",
Audience: "kubernetes-mcp-server",
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
}
jsonBytes, _ := json.Marshal(expiredClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
expiredToken := "header." + payload + ".signature"
_, err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
_, err := ParseJWTClaims(invalidPayload)
if err == nil {
t.Error("expected error for expired token, got nil")
t.Fatal("expected error for invalid base64, got nil")
}
if !strings.Contains(err.Error(), "token expired") {
if !strings.Contains(err.Error(), "illegal base64 data") {
t.Errorf("expected decode error message, got %v", err)
}
})
}
func TestJWTTokenValidate(t *testing.T) {
t.Run("expired token returns error", func(t *testing.T) {
claims, err := ParseJWTClaims(tokenBasicExpired)
if err != nil {
t.Fatalf("expected no error for expired token parsing, got %v", err)
}
err = claims.Validate("kubernetes-mcp-server")
if err == nil {
t.Fatalf("expected error for expired token, got nil")
}
if !strings.Contains(err.Error(), "token is expired (exp)") {
t.Errorf("expected expiration error message, got %v", err)
}
})
t.Run("multiple audiences with correct one", func(t *testing.T) {
// Create a token with multiple audiences including the correct one
multiAudClaims := JWTClaims{
Issuer: "test-issuer",
Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"},
ExpiresAt: time.Now().Add(time.Hour).Unix(),
Scope: "read write admin",
}
jsonBytes, _ := json.Marshal(multiAudClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
multiAudToken := "header." + payload + ".signature"
claims, err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
if err != nil {
t.Errorf("expected no error for token with multiple audiences, got %v", err)
t.Fatalf("expected no error for multiple audience token parsing, got %v", err)
}
if claims == nil {
t.Error("expected claims to be returned, got nil")
}
if claims.Issuer != "test-issuer" {
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
t.Fatalf("expected claims to be returned, got nil")
}
// Test scope parsing
err = claims.Validate("kubernetes-mcp-server")
if err != nil {
t.Fatalf("expected no error for valid audience, got %v", err)
}
})
t.Run("multiple audiences with mismatch returns error", func(t *testing.T) {
claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
if err != nil {
t.Fatalf("expected no error for multiple audience token parsing, got %v", err)
}
if claims == nil {
t.Fatalf("expected claims to be returned, got nil")
}
err = claims.Validate("missing-audience")
if err == nil {
t.Fatalf("expected error for token with wrong audience, got nil")
}
if !strings.Contains(err.Error(), "invalid audience claim (aud)") {
t.Errorf("expected audience mismatch error, got %v", err)
}
})
}
func TestJWTClaimsGetScopes(t *testing.T) {
t.Run("no scopes", func(t *testing.T) {
claims, err := ParseJWTClaims(tokenBasicExpired)
if err != nil {
t.Fatalf("expected no error for parsing token, got %v", err)
}
if scopes := claims.GetScopes(); len(scopes) != 0 {
t.Errorf("expected no scopes, got %d", len(scopes))
}
})
t.Run("single scope", func(t *testing.T) {
claims := &JWTClaims{
Scope: "read",
}
scopes := claims.GetScopes()
expectedScopes := []string{"read", "write", "admin"}
expected := []string{"read"}
if len(scopes) != 1 {
t.Errorf("expected 1 scope, got %d", len(scopes))
}
if scopes[0] != expected[0] {
t.Errorf("expected scope 'read', got '%s'", scopes[0])
}
})
t.Run("multiple scopes", func(t *testing.T) {
claims := &JWTClaims{
Scope: "read write admin",
}
scopes := claims.GetScopes()
expected := []string{"read", "write", "admin"}
if len(scopes) != 3 {
t.Errorf("expected 3 scopes, got %d", len(scopes))
}
for i, expectedScope := range expectedScopes {
for i, expectedScope := range expected {
if i >= len(scopes) || scopes[i] != expectedScope {
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
}
}
})
t.Run("audience mismatch", func(t *testing.T) {
// Create a token with wrong audience
wrongAudClaims := JWTClaims{
Issuer: "test-issuer",
Audience: "wrong-audience",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
t.Run("scopes with extra whitespace", func(t *testing.T) {
claims := &JWTClaims{
Scope: " read write admin ",
}
scopes := claims.GetScopes()
expected := []string{"read", "write", "admin"}
if len(scopes) != 3 {
t.Errorf("expected 3 scopes, got %d", len(scopes))
}
jsonBytes, _ := json.Marshal(wrongAudClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
wrongAudToken := "header." + payload + ".signature"
_, err := validateJWTToken(wrongAudToken, "audience")
if err == nil {
t.Error("expected error for token with wrong audience, got nil")
}
if !strings.Contains(err.Error(), "audience mismatch") {
t.Errorf("expected audience mismatch error, got %v", err)
for i, expectedScope := range expected {
if i >= len(scopes) || scopes[i] != expectedScope {
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
}
}
})
}
@@ -285,85 +320,3 @@ func TestAuthorizationMiddleware(t *testing.T) {
}
})
}
func TestJWTClaimsGetScopes(t *testing.T) {
t.Run("single scope", func(t *testing.T) {
claims := &JWTClaims{Scope: "read"}
scopes := claims.GetScopes()
expected := []string{"read"}
if len(scopes) != 1 {
t.Errorf("expected 1 scope, got %d", len(scopes))
}
if scopes[0] != expected[0] {
t.Errorf("expected scope 'read', got '%s'", scopes[0])
}
})
t.Run("multiple scopes", func(t *testing.T) {
claims := &JWTClaims{Scope: "read write admin"}
scopes := claims.GetScopes()
expected := []string{"read", "write", "admin"}
if len(scopes) != 3 {
t.Errorf("expected 3 scopes, got %d", len(scopes))
}
for i, expectedScope := range expected {
if i >= len(scopes) || scopes[i] != expectedScope {
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
}
}
})
t.Run("scopes with extra whitespace", func(t *testing.T) {
claims := &JWTClaims{Scope: " read write admin "}
scopes := claims.GetScopes()
expected := []string{"read", "write", "admin"}
if len(scopes) != 3 {
t.Errorf("expected 3 scopes, got %d", len(scopes))
}
for i, expectedScope := range expected {
if i >= len(scopes) || scopes[i] != expectedScope {
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
}
}
})
}
func TestJWTClaimsContainsAudience(t *testing.T) {
t.Run("single string audience", func(t *testing.T) {
claims := &JWTClaims{Audience: "test-audience"}
if !claims.ContainsAudience("test-audience") {
t.Error("expected ContainsAudience to return true for matching audience")
}
if claims.ContainsAudience("other-audience") {
t.Error("expected ContainsAudience to return false for non-matching audience")
}
})
t.Run("array audience", func(t *testing.T) {
claims := &JWTClaims{Audience: []string{"aud1", "aud2", "aud3"}}
testCases := []struct {
audience string
expected bool
}{
{"aud1", true},
{"aud2", true},
{"aud3", true},
{"aud4", false},
{"", false},
}
for _, tc := range testCases {
if claims.ContainsAudience(tc.audience) != tc.expected {
t.Errorf("expected ContainsAudience(%s) to return %v", tc.audience, tc.expected)
}
}
})
}