mirror of
https://github.com/containers/kubernetes-mcp-server.git
synced 2025-10-23 01:22:57 +03:00
fix(auth): delegate JWT parsing to github.com/go-jose/go-jose (189)
fix(auth): delegate JWT parsing to github.com/golang-jwt/jwt Signed-off-by: Marc Nuri <marc@marcnuri.com> --- fix(auth): delegate JWT parsing to go-jose Signed-off-by: Marc Nuri <marc@marcnuri.com> --- fix(auth): delegate JWT parsing to go-jose - review comment Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
2
go.mod
2
go.mod
@@ -6,6 +6,7 @@ require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-jose/go-jose/v4 v4.0.5
|
||||
github.com/mark3labs/mcp-go v0.34.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/spf13/afero v1.14.0
|
||||
@@ -52,7 +53,6 @@ require (
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||
github.com/go-errors/errors v1.4.2 // indirect
|
||||
github.com/go-gorp/gorp/v3 v3.1.0 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/jsonreference v0.20.2 // indirect
|
||||
|
||||
@@ -2,14 +2,13 @@ package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
|
||||
@@ -55,7 +54,10 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
// Validate the token offline for simple sanity check
|
||||
// Because missing expected audience and expired tokens must be
|
||||
// rejected already.
|
||||
claims, err := validateJWTToken(token, audience)
|
||||
claims, err := ParseJWTClaims(token)
|
||||
if err == nil && claims != nil {
|
||||
err = claims.Validate(audience)
|
||||
}
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
@@ -117,11 +119,25 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
}
|
||||
}
|
||||
|
||||
var allSignatureAlgorithms = []jose.SignatureAlgorithm{
|
||||
jose.EdDSA,
|
||||
jose.HS256,
|
||||
jose.HS384,
|
||||
jose.HS512,
|
||||
jose.RS256,
|
||||
jose.RS384,
|
||||
jose.RS512,
|
||||
jose.ES256,
|
||||
jose.ES384,
|
||||
jose.ES512,
|
||||
jose.PS256,
|
||||
jose.PS384,
|
||||
jose.PS512,
|
||||
}
|
||||
|
||||
type JWTClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Audience any `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
jwt.Claims
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
func (c *JWTClaims) GetScopes() []string {
|
||||
@@ -131,66 +147,21 @@ func (c *JWTClaims) GetScopes() []string {
|
||||
return strings.Fields(c.Scope)
|
||||
}
|
||||
|
||||
func (c *JWTClaims) ContainsAudience(audience string) bool {
|
||||
switch aud := c.Audience.(type) {
|
||||
case string:
|
||||
return aud == audience
|
||||
case []interface{}:
|
||||
for _, a := range aud {
|
||||
if str, ok := a.(string); ok && str == audience {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, a := range aud {
|
||||
if a == audience {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
|
||||
func (c *JWTClaims) Validate(audience string) error {
|
||||
return c.Claims.Validate(jwt.Expected{
|
||||
AnyAudience: jwt.Audience{audience},
|
||||
})
|
||||
}
|
||||
|
||||
// validateJWTToken validates basic JWT claims without signature verification and returns the claims
|
||||
func validateJWTToken(token, audience string) (*JWTClaims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT token format")
|
||||
}
|
||||
|
||||
claims, err := parseJWTClaims(parts[1])
|
||||
func ParseJWTClaims(token string) (*JWTClaims, error) {
|
||||
tkn, err := jwt.ParseSigned(token, allSignatureAlgorithms)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
|
||||
}
|
||||
|
||||
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if !claims.ContainsAudience(audience) {
|
||||
return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func parseJWTClaims(payload string) (*JWTClaims, error) {
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
claims := &JWTClaims{}
|
||||
err = tkn.UnsafeClaimsWithoutVerification(claims)
|
||||
return claims, err
|
||||
}
|
||||
|
||||
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {
|
||||
|
||||
@@ -1,187 +1,222 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
)
|
||||
|
||||
func TestParseJWTClaims(t *testing.T) {
|
||||
t.Run("valid JWT payload", func(t *testing.T) {
|
||||
// Sample payload from a valid JWT
|
||||
payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9"
|
||||
const (
|
||||
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA // notsecret
|
||||
tokenBasicNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0In0.0363P6xGmWpU-O9TAVkcOd95lPXxhI-_k5NKbHGNQeL--B8XMAz2vC8hpKnyC6rKOGifRTSR2XNHx_5fjd7lEA" // notsecret
|
||||
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA // notsecret
|
||||
tokenBasicExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxLCJpYXQiOjAsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9.USsuGLsB_7MwG9i0__cFkVVZa0djtmQpc8Vwi56GrapAgVAcyTfmae3s83XMDP5AwcFnxhYxLCfiZWRJri6GTA" // notsecret
|
||||
// https://jwt.io/#token=eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ // notsecret
|
||||
tokenMultipleAudienceNotExpired = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6Ijk4ZDU3YmUwNWI3ZjUzNWIwMzYyYjg2MDJhNTJlNGYxIn0.eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoyNTM0MDIyOTcxOTksImlhdCI6MCwiaXNzIjoiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJqdGkiOiI5OTIyMmQ1Ni0zNDBlLTRlYjYtODU4OC0yNjE0MTFmMzVkMjYiLCJrdWJlcm5ldGVzLmlvIjp7Im5hbWVzcGFjZSI6ImRlZmF1bHQiLCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoiZGVmYXVsdCIsInVpZCI6ImVhY2I2YWQyLTgwYjctNDE3OS04NDNkLTkyZWIxZTZiYmJhNiJ9fSwibmJmIjowLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6ZGVmYXVsdDpkZWZhdWx0Iiwic2NvcGUiOiJyZWFkIHdyaXRlIn0.vl5se9BuxoVDhvR7M5wGfkLoyMSYUiORMZVxl0CQ7jw3x53mZfGEkU_kkIVIl9Ui371qCCVVxdvuZPcAgbM6pQ" // notsecret
|
||||
)
|
||||
|
||||
claims, err := parseJWTClaims(payload)
|
||||
func TestParseJWTClaimsPayloadValid(t *testing.T) {
|
||||
basicClaims, err := ParseJWTClaims(tokenBasicNotExpired)
|
||||
t.Run("Is parseable", func(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if claims == nil {
|
||||
if basicClaims == nil {
|
||||
t.Fatal("expected claims, got nil")
|
||||
}
|
||||
|
||||
if claims.Issuer != "https://kubernetes.default.svc.cluster.local" {
|
||||
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer)
|
||||
})
|
||||
t.Run("Parses issuer", func(t *testing.T) {
|
||||
if basicClaims.Issuer != "https://kubernetes.default.svc.cluster.local" {
|
||||
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", basicClaims.Issuer)
|
||||
}
|
||||
|
||||
})
|
||||
t.Run("Parses audience", func(t *testing.T) {
|
||||
expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"}
|
||||
for _, expected := range expectedAudiences {
|
||||
if !claims.ContainsAudience(expected) {
|
||||
if !basicClaims.Audience.Contains(expected) {
|
||||
t.Errorf("expected audience to contain %s", expected)
|
||||
}
|
||||
}
|
||||
|
||||
if claims.ExpiresAt != 1751963948 {
|
||||
t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt)
|
||||
})
|
||||
t.Run("Parses expiration", func(t *testing.T) {
|
||||
if *basicClaims.Expiry != jwt.NumericDate(253402297199) {
|
||||
t.Errorf("expected expiration 253402297199, got %d", basicClaims.Expiry)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("payload needs padding", func(t *testing.T) {
|
||||
// Create a payload that needs padding
|
||||
testClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
t.Run("Parses scope", func(t *testing.T) {
|
||||
scopeClaims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if scopeClaims == nil {
|
||||
t.Fatal("expected claims, got nil")
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(testClaims)
|
||||
// Create a payload without proper padding
|
||||
encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=")
|
||||
scopes := scopeClaims.GetScopes()
|
||||
|
||||
claims, err := parseJWTClaims(encodedWithoutPadding)
|
||||
expectedScopes := []string{"read", "write"}
|
||||
if len(scopes) != len(expectedScopes) {
|
||||
t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(scopes))
|
||||
}
|
||||
for i, expectedScope := range expectedScopes {
|
||||
if scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("Parses expired token", func(t *testing.T) {
|
||||
expiredClaims, err := ParseJWTClaims(tokenBasicExpired)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if claims.Issuer != "test-issuer" {
|
||||
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid base64 payload", func(t *testing.T) {
|
||||
invalidPayload := "invalid-base64!!!"
|
||||
|
||||
_, err := parseJWTClaims(invalidPayload)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to decode JWT payload") {
|
||||
t.Errorf("expected decode error message, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON payload", func(t *testing.T) {
|
||||
// Valid base64 but invalid JSON
|
||||
invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json"))
|
||||
|
||||
_, err := parseJWTClaims(invalidJSON)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") {
|
||||
t.Errorf("expected unmarshal error message, got %v", err)
|
||||
if *expiredClaims.Expiry != jwt.NumericDate(1) {
|
||||
t.Errorf("expected expiration 1, got %d", basicClaims.Expiry)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateJWTToken(t *testing.T) {
|
||||
t.Run("invalid token format - not enough parts", func(t *testing.T) {
|
||||
invalidToken := "header.payload"
|
||||
func TestParseJWTClaimsPayloadInvalid(t *testing.T) {
|
||||
t.Run("invalid token segments", func(t *testing.T) {
|
||||
invalidToken := "header.payload.signature.extra"
|
||||
|
||||
_, err := validateJWTToken(invalidToken, "test")
|
||||
_, err := ParseJWTClaims(invalidToken)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid token format, got nil")
|
||||
t.Fatal("expected error for invalid token segments, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid JWT token format") {
|
||||
t.Errorf("expected format error message, got %v", err)
|
||||
if !strings.Contains(err.Error(), "compact JWS format must have three parts") {
|
||||
t.Errorf("expected invalid token segments error message, got %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("invalid base64 payload", func(t *testing.T) {
|
||||
invalidPayload := "invalid_base64" + tokenBasicNotExpired
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
// Create an expired token
|
||||
expiredClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "kubernetes-mcp-server",
|
||||
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(expiredClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
expiredToken := "header." + payload + ".signature"
|
||||
|
||||
_, err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
|
||||
_, err := ParseJWTClaims(invalidPayload)
|
||||
if err == nil {
|
||||
t.Error("expected error for expired token, got nil")
|
||||
t.Fatal("expected error for invalid base64, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "token expired") {
|
||||
if !strings.Contains(err.Error(), "illegal base64 data") {
|
||||
t.Errorf("expected decode error message, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTTokenValidate(t *testing.T) {
|
||||
t.Run("expired token returns error", func(t *testing.T) {
|
||||
claims, err := ParseJWTClaims(tokenBasicExpired)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for expired token parsing, got %v", err)
|
||||
}
|
||||
|
||||
err = claims.Validate("kubernetes-mcp-server")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for expired token, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "token is expired (exp)") {
|
||||
t.Errorf("expected expiration error message, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple audiences with correct one", func(t *testing.T) {
|
||||
// Create a token with multiple audiences including the correct one
|
||||
multiAudClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"},
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
Scope: "read write admin",
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(multiAudClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
multiAudToken := "header." + payload + ".signature"
|
||||
|
||||
claims, err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
|
||||
claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error for token with multiple audiences, got %v", err)
|
||||
t.Fatalf("expected no error for multiple audience token parsing, got %v", err)
|
||||
}
|
||||
if claims == nil {
|
||||
t.Error("expected claims to be returned, got nil")
|
||||
}
|
||||
if claims.Issuer != "test-issuer" {
|
||||
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
|
||||
t.Fatalf("expected claims to be returned, got nil")
|
||||
}
|
||||
|
||||
// Test scope parsing
|
||||
err = claims.Validate("kubernetes-mcp-server")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for valid audience, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple audiences with mismatch returns error", func(t *testing.T) {
|
||||
claims, err := ParseJWTClaims(tokenMultipleAudienceNotExpired)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for multiple audience token parsing, got %v", err)
|
||||
}
|
||||
if claims == nil {
|
||||
t.Fatalf("expected claims to be returned, got nil")
|
||||
}
|
||||
|
||||
err = claims.Validate("missing-audience")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for token with wrong audience, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid audience claim (aud)") {
|
||||
t.Errorf("expected audience mismatch error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTClaimsGetScopes(t *testing.T) {
|
||||
t.Run("no scopes", func(t *testing.T) {
|
||||
claims, err := ParseJWTClaims(tokenBasicExpired)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for parsing token, got %v", err)
|
||||
}
|
||||
|
||||
if scopes := claims.GetScopes(); len(scopes) != 0 {
|
||||
t.Errorf("expected no scopes, got %d", len(scopes))
|
||||
}
|
||||
})
|
||||
t.Run("single scope", func(t *testing.T) {
|
||||
claims := &JWTClaims{
|
||||
Scope: "read",
|
||||
}
|
||||
scopes := claims.GetScopes()
|
||||
expectedScopes := []string{"read", "write", "admin"}
|
||||
expected := []string{"read"}
|
||||
|
||||
if len(scopes) != 1 {
|
||||
t.Errorf("expected 1 scope, got %d", len(scopes))
|
||||
}
|
||||
if scopes[0] != expected[0] {
|
||||
t.Errorf("expected scope 'read', got '%s'", scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple scopes", func(t *testing.T) {
|
||||
claims := &JWTClaims{
|
||||
Scope: "read write admin",
|
||||
}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
for i, expectedScope := range expectedScopes {
|
||||
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("audience mismatch", func(t *testing.T) {
|
||||
// Create a token with wrong audience
|
||||
wrongAudClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "wrong-audience",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
t.Run("scopes with extra whitespace", func(t *testing.T) {
|
||||
claims := &JWTClaims{
|
||||
Scope: " read write admin ",
|
||||
}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(wrongAudClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
wrongAudToken := "header." + payload + ".signature"
|
||||
|
||||
_, err := validateJWTToken(wrongAudToken, "audience")
|
||||
if err == nil {
|
||||
t.Error("expected error for token with wrong audience, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "audience mismatch") {
|
||||
t.Errorf("expected audience mismatch error, got %v", err)
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -285,85 +320,3 @@ func TestAuthorizationMiddleware(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTClaimsGetScopes(t *testing.T) {
|
||||
t.Run("single scope", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: "read"}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read"}
|
||||
|
||||
if len(scopes) != 1 {
|
||||
t.Errorf("expected 1 scope, got %d", len(scopes))
|
||||
}
|
||||
if scopes[0] != expected[0] {
|
||||
t.Errorf("expected scope 'read', got '%s'", scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple scopes", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: "read write admin"}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("scopes with extra whitespace", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: " read write admin "}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTClaimsContainsAudience(t *testing.T) {
|
||||
t.Run("single string audience", func(t *testing.T) {
|
||||
claims := &JWTClaims{Audience: "test-audience"}
|
||||
|
||||
if !claims.ContainsAudience("test-audience") {
|
||||
t.Error("expected ContainsAudience to return true for matching audience")
|
||||
}
|
||||
|
||||
if claims.ContainsAudience("other-audience") {
|
||||
t.Error("expected ContainsAudience to return false for non-matching audience")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("array audience", func(t *testing.T) {
|
||||
claims := &JWTClaims{Audience: []string{"aud1", "aud2", "aud3"}}
|
||||
|
||||
testCases := []struct {
|
||||
audience string
|
||||
expected bool
|
||||
}{
|
||||
{"aud1", true},
|
||||
{"aud2", true},
|
||||
{"aud3", true},
|
||||
{"aud4", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if claims.ContainsAudience(tc.audience) != tc.expected {
|
||||
t.Errorf("expected ContainsAudience(%s) to return %v", tc.audience, tc.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user