test(auth): complete test suite for unauthorized scenarios (#220)

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-07-29 13:32:31 +02:00
committed by GitHub
parent aa14e31eba
commit 1f670ebec6
3 changed files with 285 additions and 152 deletions

View File

@@ -41,9 +41,9 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
} else {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="missing_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
}
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
return
@@ -103,7 +103,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
// with the other token in the headers (TODO: still need to validate aud and exp of this token separately).
_, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience)
if err != nil {
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
if serverURL == "" {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))

View File

@@ -1,8 +1,6 @@
package http
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
@@ -220,103 +218,3 @@ func TestJWTClaimsGetScopes(t *testing.T) {
}
})
}
func TestAuthorizationMiddleware(t *testing.T) {
// Create a mock handler
handlerCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
})
t.Run("OAuth disabled - passes through", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth disabled
middleware := AuthorizationMiddleware(false, "", nil, nil)
wrappedHandler := middleware(handler)
// Create request without authorization header
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if !handlerCalled {
t.Error("expected handler to be called when OAuth is disabled")
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
})
t.Run("healthz endpoint - passes through", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil, nil)
wrappedHandler := middleware(handler)
// Create request to healthz endpoint
req := httptest.NewRequest("GET", "/healthz", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if !handlerCalled {
t.Error("expected handler to be called for healthz endpoint")
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
})
t.Run("OAuth enabled - missing token", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil, nil)
wrappedHandler := middleware(handler)
// Create request without authorization header
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if handlerCalled {
t.Error("expected handler NOT to be called when token is missing")
}
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "Bearer token required") {
t.Errorf("expected bearer token error message, got %s", w.Body.String())
}
})
t.Run("OAuth enabled - invalid token format", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil, nil)
wrappedHandler := middleware(handler)
// Create request with invalid bearer token
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer invalid-token")
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if handlerCalled {
t.Error("expected handler NOT to be called when token is invalid")
}
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "Invalid token") {
t.Errorf("expected invalid token error message, got %s", w.Body.String())
}
})
}

View File

@@ -4,17 +4,23 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"flag"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-oidc/v3/oidc/oidctest"
"golang.org/x/sync/errgroup"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/client-go/tools/clientcmd/api"
@@ -26,17 +32,22 @@ import (
)
type httpContext struct {
t *testing.T
klogState klog.State
logBuffer bytes.Buffer
httpAddress string // HTTP server address
LogBuffer bytes.Buffer
HttpAddress string // HTTP server address
timeoutCancel context.CancelFunc // Release resources if test completes before the timeout
stopServer context.CancelFunc
waitForShutdown func() error
StopServer context.CancelFunc
WaitForShutdown func() error
StaticConfig *config.StaticConfig
OidcProvider *oidc.Provider
}
func (c *httpContext) beforeEach() {
func (c *httpContext) beforeEach(t *testing.T) {
t.Helper()
http.DefaultClient.Timeout = 10 * time.Second
if c.StaticConfig == nil {
c.StaticConfig = &config.StaticConfig{}
}
// Fake Kubernetes configuration
fakeConfig := api.NewConfig()
fakeConfig.Clusters["fake"] = api.NewCluster()
@@ -44,7 +55,7 @@ func (c *httpContext) beforeEach() {
fakeConfig.Contexts["fake-context"] = api.NewContext()
fakeConfig.Contexts["fake-context"].Cluster = "fake"
fakeConfig.CurrentContext = "fake-context"
kubeConfig := filepath.Join(c.t.TempDir(), "config")
kubeConfig := filepath.Join(t.TempDir(), "config")
_ = clientcmd.WriteToFile(*fakeConfig, kubeConfig)
_ = os.Setenv("KUBECONFIG", kubeConfig)
// Capture logging
@@ -52,33 +63,33 @@ func (c *httpContext) beforeEach() {
flags := flag.NewFlagSet("test", flag.ContinueOnError)
klog.InitFlags(flags)
_ = flags.Set("v", "5")
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.logBuffer))))
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.LogBuffer))))
// Start server in random port
ln, err := net.Listen("tcp", "0.0.0.0:0")
if err != nil {
c.t.Fatalf("Failed to find random port for HTTP server: %v", err)
t.Fatalf("Failed to find random port for HTTP server: %v", err)
}
c.httpAddress = ln.Addr().String()
c.HttpAddress = ln.Addr().String()
if randomPortErr := ln.Close(); randomPortErr != nil {
c.t.Fatalf("Failed to close random port listener: %v", randomPortErr)
t.Fatalf("Failed to close random port listener: %v", randomPortErr)
}
staticConfig := &config.StaticConfig{Port: fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)}
c.StaticConfig.Port = fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)
mcpServer, err := mcp.NewServer(mcp.Configuration{
Profile: mcp.Profiles[0],
StaticConfig: staticConfig,
StaticConfig: c.StaticConfig,
})
if err != nil {
c.t.Fatalf("Failed to create MCP server: %v", err)
t.Fatalf("Failed to create MCP server: %v", err)
}
var timeoutCtx, cancelCtx context.Context
timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second)
timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second)
group, gc := errgroup.WithContext(timeoutCtx)
cancelCtx, c.stopServer = context.WithCancel(gc)
group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig, nil) })
c.waitForShutdown = group.Wait
cancelCtx, c.StopServer = context.WithCancel(gc)
group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) })
c.WaitForShutdown = group.Wait
// Wait for HTTP server to start (using net)
for i := 0; i < 10; i++ {
conn, err := net.Dial("tcp", c.httpAddress)
conn, err := net.Dial("tcp", c.HttpAddress)
if err == nil {
_ = conn.Close()
break
@@ -87,11 +98,12 @@ func (c *httpContext) beforeEach() {
}
}
func (c *httpContext) afterEach() {
c.stopServer()
err := c.waitForShutdown()
func (c *httpContext) afterEach(t *testing.T) {
t.Helper()
c.StopServer()
err := c.WaitForShutdown()
if err != nil {
c.t.Errorf("HTTP server did not shut down gracefully: %v", err)
t.Errorf("HTTP server did not shut down gracefully: %v", err)
}
c.timeoutCancel()
c.klogState.Restore()
@@ -99,34 +111,61 @@ func (c *httpContext) afterEach() {
}
func testCase(t *testing.T, test func(c *httpContext)) {
ctx := &httpContext{t: t}
ctx.beforeEach()
t.Cleanup(ctx.afterEach)
test(ctx)
testCaseWithContext(t, &httpContext{}, test)
}
func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpContext)) {
httpCtx.beforeEach(t)
t.Cleanup(func() { httpCtx.afterEach(t) })
test(httpCtx)
}
func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key for oidc: %v", err)
}
oidcServer := &oidctest.Server{
PublicKeys: []oidctest.PublicKey{
{
PublicKey: privateKey.Public(),
KeyID: "test-oidc-key-id",
Algorithm: oidc.RS256,
},
},
}
httpServer = httptest.NewServer(oidcServer)
oidcServer.SetIssuer(httpServer.URL)
oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL)
if err != nil {
t.Fatalf("failed to create OIDC provider: %v", err)
}
return
}
func TestGracefulShutdown(t *testing.T) {
testCase(t, func(ctx *httpContext) {
ctx.stopServer()
err := ctx.waitForShutdown()
ctx.StopServer()
err := ctx.WaitForShutdown()
t.Run("Stops gracefully", func(t *testing.T) {
if err != nil {
t.Errorf("Expected graceful shutdown, but got error: %v", err)
}
})
t.Run("Stops on context cancel", func(t *testing.T) {
if !strings.Contains(ctx.logBuffer.String(), "Context cancelled, initiating graceful shutdown") {
t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.logBuffer.String())
if !strings.Contains(ctx.LogBuffer.String(), "Context cancelled, initiating graceful shutdown") {
t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.LogBuffer.String())
}
})
t.Run("Starts server shutdown", func(t *testing.T) {
if !strings.Contains(ctx.logBuffer.String(), "Shutting down HTTP server gracefully") {
t.Errorf("Expected graceful shutdown log, got: %s", ctx.logBuffer.String())
if !strings.Contains(ctx.LogBuffer.String(), "Shutting down HTTP server gracefully") {
t.Errorf("Expected graceful shutdown log, got: %s", ctx.LogBuffer.String())
}
})
t.Run("Server shutdown completes", func(t *testing.T) {
if !strings.Contains(ctx.logBuffer.String(), "HTTP server shutdown complete") {
t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.logBuffer.String())
if !strings.Contains(ctx.LogBuffer.String(), "HTTP server shutdown complete") {
t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.LogBuffer.String())
}
})
})
@@ -134,7 +173,7 @@ func TestGracefulShutdown(t *testing.T) {
func TestSseTransport(t *testing.T) {
testCase(t, func(ctx *httpContext) {
sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.httpAddress))
sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.HttpAddress))
t.Cleanup(func() { _ = sseResp.Body.Close() })
t.Run("Exposes SSE endpoint at /sse", func(t *testing.T) {
if sseErr != nil {
@@ -167,7 +206,7 @@ func TestSseTransport(t *testing.T) {
}
})
messageResp, messageErr := http.Post(
fmt.Sprintf("http://%s/message?sessionId=%s", ctx.httpAddress, strings.TrimSpace(endpoint[25:])),
fmt.Sprintf("http://%s/message?sessionId=%s", ctx.HttpAddress, strings.TrimSpace(endpoint[25:])),
"application/json",
bytes.NewBufferString("{}"),
)
@@ -185,7 +224,7 @@ func TestSseTransport(t *testing.T) {
func TestStreamableHttpTransport(t *testing.T) {
testCase(t, func(ctx *httpContext) {
mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.httpAddress))
mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
t.Cleanup(func() { _ = mcpGetResp.Body.Close() })
t.Run("Exposes MCP GET endpoint at /mcp", func(t *testing.T) {
if mcpGetErr != nil {
@@ -200,7 +239,7 @@ func TestStreamableHttpTransport(t *testing.T) {
t.Errorf("Expected Content-Type text/event-stream (GET), got %s", mcpGetResp.Header.Get("Content-Type"))
}
})
mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.httpAddress), "application/json", bytes.NewBufferString("{}"))
mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), "application/json", bytes.NewBufferString("{}"))
t.Cleanup(func() { _ = mcpPostResp.Body.Close() })
t.Run("Exposes MCP POST endpoint at /mcp", func(t *testing.T) {
if mcpPostErr != nil {
@@ -221,7 +260,7 @@ func TestStreamableHttpTransport(t *testing.T) {
func TestHealthCheck(t *testing.T) {
testCase(t, func(ctx *httpContext) {
t.Run("Exposes health check endpoint at /healthz", func(t *testing.T) {
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.httpAddress))
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get health check endpoint: %v", err)
}
@@ -231,11 +270,24 @@ func TestHealthCheck(t *testing.T) {
}
})
})
// Health exposed even when require Authorization
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Health check with OAuth returns HTTP 200 OK", func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
})
}
func TestWellKnownOAuthProtectedResource(t *testing.T) {
testCase(t, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress))
resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress))
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run("Exposes .well-known/oauth-protected-resource endpoint", func(t *testing.T) {
if err != nil {
@@ -255,17 +307,17 @@ func TestWellKnownOAuthProtectedResource(t *testing.T) {
func TestMiddlewareLogging(t *testing.T) {
testCase(t, func(ctx *httpContext) {
_, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress))
_, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress))
t.Run("Logs HTTP requests and responses", func(t *testing.T) {
if !strings.Contains(ctx.logBuffer.String(), "GET /.well-known/oauth-protected-resource 200") {
t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.logBuffer.String())
if !strings.Contains(ctx.LogBuffer.String(), "GET /.well-known/oauth-protected-resource 200") {
t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.LogBuffer.String())
}
})
t.Run("Logs HTTP request duration", func(t *testing.T) {
expected := `"GET /.well-known/oauth-protected-resource 200 (.+)"`
m := regexp.MustCompile(expected).FindStringSubmatch(ctx.logBuffer.String())
m := regexp.MustCompile(expected).FindStringSubmatch(ctx.LogBuffer.String())
if len(m) != 2 {
t.Fatalf("Expected log entry to contain duration, got %s", ctx.logBuffer.String())
t.Fatalf("Expected log entry to contain duration, got %s", ctx.LogBuffer.String())
}
duration, err := time.ParseDuration(m[1])
if err != nil {
@@ -276,5 +328,188 @@ func TestMiddlewareLogging(t *testing.T) {
}
})
})
}
func TestAuthorizationUnauthorized(t *testing.T) {
// Missing Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with MISSING Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with MISSING Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") {
t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String())
}
})
})
// Authorization header without Bearer prefix
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INCOMPATIBLE Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") {
t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String())
}
})
})
// Invalid Authorization header
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer invalid_base64"+tokenBasicNotExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INVALID Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
!strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") {
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
}
})
})
// Expired Authorization Bearer token
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+tokenBasicExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with EXPIRED Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
!strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") {
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
}
})
})
// Failed OIDC validation
key, oidcProvider, httpServer := NewOidcTestServer(t)
t.Cleanup(httpServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INVALID OIDC Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with INVALID OIDC Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") &&
!strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") {
t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String())
}
})
})
// Failed Kubernetes TokenReview
rawClaims := `{
"iss": "` + httpServer.URL + `",
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
"aud": "kubernetes-mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+validOidcToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close })
t.Run("Protected resource with INVALID KUBERNETES Authorization header returns 401 - Unauthorized", func(t *testing.T) {
if resp.StatusCode != 401 {
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
}
})
t.Run("Protected resource with INVALID KUBERNETES Authorization header returns WWW-Authenticate header", func(t *testing.T) {
authHeader := resp.Header.Get("WWW-Authenticate")
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
if authHeader != expected {
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
}
})
t.Run("Protected resource with INVALID KUBERNETES Authorization header logs error", func(t *testing.T) {
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - API Server token validation error") {
t.Errorf("Expected log entry for Kubernetes TokenReview error, got: %s", ctx.LogBuffer.String())
}
})
})
}