test(auth): complete test scenarios for raw token and oidc (#248)

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-08-07 16:04:12 +03:00
committed by GitHub
parent 43744f2978
commit cfc42b3bd3
7 changed files with 155 additions and 33 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
authenticationapiv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
"k8s.io/utils/strings/slices"
@@ -19,8 +20,37 @@ const (
Audience = "mcp-server"
)
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *oidc.Provider, mcpServer *mcp.Server) func(http.Handler) http.Handler {
type KubernetesApiTokenVerifier interface {
// KubernetesApiVerifyToken TODO: clarify proper implementation
KubernetesApiVerifyToken(ctx context.Context, token, audience string) (*authenticationapiv1.UserInfo, []string, error)
}
// AuthorizationMiddleware validates the OAuth flow for protected resources.
//
// The flow is skipped for unprotected resources, such as health checks and well-known endpoints.
//
// There are several auth scenarios
//
// 1. requireOAuth is false:
//
// - The OAuth flow is skipped, and the server is effectively unprotected.
// - The request is passed to the next handler without any validation.
//
// see TestAuthorizationRequireOAuthFalse
//
// 2. requireOAuth is set to true, server is protected:
//
// 2.1. Raw Token Validation (oidcProvider is nil):
// - The token is validated offline for basic sanity checks (audience and expiration).
// - The token is then used against the Kubernetes API Server for TokenReview.
//
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
// - The token is validated offline for basic sanity checks (audience and expiration).
// - The token is then validated against the OIDC Provider.
// - The token is then used against the Kubernetes API Server for TokenReview.
//
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
@@ -38,20 +68,13 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
} else {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="missing_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
}
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
// Validate the token offline for simple sanity check
// Because missing expected audience and expired tokens must be
// rejected already.
claims, err := ParseJWTClaims(token)
if err == nil && claims != nil {
err = claims.Validate(r.Context(), audience, oidcProvider)
@@ -59,11 +82,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
} else {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
}
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
@@ -85,15 +104,11 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
// 2. b. If this is not the only token in the headers, the token in here is used
// only for authentication and authorization. Therefore, we need to send TokenReview request
// with the other token in the headers (TODO: still need to validate aud and exp of this token separately).
_, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience)
_, _, err = verifier.KubernetesApiVerifyToken(r.Context(), token, audience)
if err != nil {
klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
} else {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
}
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}

View File

@@ -28,7 +28,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
mux := http.NewServeMux()
wrappedMux := RequestMiddleware(
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, oidcProvider, mcpServer)(mux),
AuthorizationMiddleware(staticConfig.RequireOAuth, oidcProvider, mcpServer)(mux),
)
httpServer := &http.Server{

View File

@@ -19,11 +19,11 @@ import (
"testing"
"time"
"github.com/containers/kubernetes-mcp-server/internal/test"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-oidc/v3/oidc/oidctest"
"golang.org/x/sync/errgroup"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/client-go/tools/clientcmd/api"
"k8s.io/klog/v2"
"k8s.io/klog/v2/textlogger"
@@ -33,6 +33,7 @@ import (
type httpContext struct {
klogState klog.State
mockServer *test.MockServer
LogBuffer bytes.Buffer
HttpAddress string // HTTP server address
timeoutCancel context.CancelFunc // Release resources if test completes before the timeout
@@ -42,21 +43,31 @@ type httpContext struct {
OidcProvider *oidc.Provider
}
const tokenReviewSuccessful = `
{
"kind": "TokenReview",
"apiVersion": "authentication.k8s.io/v1",
"spec": {"token": "valid-token"},
"status": {
"authenticated": true,
"user": {
"username": "test-user",
"groups": ["system:authenticated"]
}
}
}`
func (c *httpContext) beforeEach(t *testing.T) {
t.Helper()
http.DefaultClient.Timeout = 10 * time.Second
if c.StaticConfig == nil {
c.StaticConfig = &config.StaticConfig{}
}
c.mockServer = test.NewMockServer()
// Fake Kubernetes configuration
fakeConfig := api.NewConfig()
fakeConfig.Clusters["fake"] = api.NewCluster()
fakeConfig.Clusters["fake"].Server = "https://example.com"
fakeConfig.Contexts["fake-context"] = api.NewContext()
fakeConfig.Contexts["fake-context"].Cluster = "fake"
fakeConfig.CurrentContext = "fake-context"
mockKubeConfig := c.mockServer.KubeConfig()
kubeConfig := filepath.Join(t.TempDir(), "config")
_ = clientcmd.WriteToFile(*fakeConfig, kubeConfig)
_ = clientcmd.WriteToFile(*mockKubeConfig, kubeConfig)
_ = os.Setenv("KUBECONFIG", kubeConfig)
// Capture logging
c.klogState = klog.CaptureState()
@@ -100,6 +111,7 @@ func (c *httpContext) beforeEach(t *testing.T) {
func (c *httpContext) afterEach(t *testing.T) {
t.Helper()
c.mockServer.Close()
c.StopServer()
err := c.WaitForShutdown()
if err != nil {
@@ -546,3 +558,81 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
}
// TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
func TestAuthorizationRequireOAuthFalse(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Protected resource with MISSING Authorization header returns 200 - OK)", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
})
}
func TestAuthorizationRawToken(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tokenReviewSuccessful))
return
}
}))
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Protected resource with VALID Authorization header returns 200 - OK", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
})
}
func TestAuthorizationOidcToken(t *testing.T) {
key, oidcProvider, httpServer := NewOidcTestServer(t)
t.Cleanup(httpServer.Close)
rawClaims := `{
"iss": "` + httpServer.URL + `",
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tokenReviewSuccessful))
return
}
}))
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+validOidcToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Protected resource with VALID OIDC Authorization header returns 200 - OK", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
})
}

View File

@@ -3,6 +3,7 @@ package kubernetes
import (
"context"
"fmt"
authenticationv1api "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

View File

@@ -122,9 +122,9 @@ func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer
return server.NewStreamableHTTPServer(s.server, options...)
}
// VerifyTokenAPIServer verifies the given token with the audience by
// KubernetesApiVerifyToken verifies the given token with the audience by
// sending an TokenReview request to API Server.
func (s *Server) VerifyTokenAPIServer(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
func (s *Server) KubernetesApiVerifyToken(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
if s.k == nil {
return nil, nil, fmt.Errorf("kubernetes manager is not initialized")
}