diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 5ca0edc..b604416 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -41,9 +41,9 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) if serverURL == "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience)) } else { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint)) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="missing_token"`, audience, serverURL, oauthProtectedResourceEndpoint)) } http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized) return @@ -103,7 +103,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * // with the other token in the headers (TODO: still need to validate aud and exp of this token separately). _, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience) if err != nil { - klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) + klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) if serverURL == "" { w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) diff --git a/pkg/http/authorization_test.go b/pkg/http/authorization_test.go index 7849483..31ad804 100644 --- a/pkg/http/authorization_test.go +++ b/pkg/http/authorization_test.go @@ -1,8 +1,6 @@ package http import ( - "net/http" - "net/http/httptest" "strings" "testing" @@ -220,103 +218,3 @@ func TestJWTClaimsGetScopes(t *testing.T) { } }) } - -func TestAuthorizationMiddleware(t *testing.T) { - // Create a mock handler - handlerCalled := false - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlerCalled = true - w.WriteHeader(http.StatusOK) - }) - - t.Run("OAuth disabled - passes through", func(t *testing.T) { - handlerCalled = false - - // Create middleware with OAuth disabled - middleware := AuthorizationMiddleware(false, "", nil, nil) - wrappedHandler := middleware(handler) - - // Create request without authorization header - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(w, req) - - if !handlerCalled { - t.Error("expected handler to be called when OAuth is disabled") - } - if w.Code != http.StatusOK { - t.Errorf("expected status 200, got %d", w.Code) - } - }) - - t.Run("healthz endpoint - passes through", func(t *testing.T) { - handlerCalled = false - - // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil, nil) - wrappedHandler := middleware(handler) - - // Create request to healthz endpoint - req := httptest.NewRequest("GET", "/healthz", nil) - w := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(w, req) - - if !handlerCalled { - t.Error("expected handler to be called for healthz endpoint") - } - if w.Code != http.StatusOK { - t.Errorf("expected status 200, got %d", w.Code) - } - }) - - t.Run("OAuth enabled - missing token", func(t *testing.T) { - handlerCalled = false - - // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil, nil) - wrappedHandler := middleware(handler) - - // Create request without authorization header - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(w, req) - - if handlerCalled { - t.Error("expected handler NOT to be called when token is missing") - } - if w.Code != http.StatusUnauthorized { - t.Errorf("expected status 401, got %d", w.Code) - } - if !strings.Contains(w.Body.String(), "Bearer token required") { - t.Errorf("expected bearer token error message, got %s", w.Body.String()) - } - }) - - t.Run("OAuth enabled - invalid token format", func(t *testing.T) { - handlerCalled = false - - // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil, nil) - wrappedHandler := middleware(handler) - - // Create request with invalid bearer token - req := httptest.NewRequest("GET", "/test", nil) - req.Header.Set("Authorization", "Bearer invalid-token") - w := httptest.NewRecorder() - - wrappedHandler.ServeHTTP(w, req) - - if handlerCalled { - t.Error("expected handler NOT to be called when token is invalid") - } - if w.Code != http.StatusUnauthorized { - t.Errorf("expected status 401, got %d", w.Code) - } - if !strings.Contains(w.Body.String(), "Invalid token") { - t.Errorf("expected invalid token error message, got %s", w.Body.String()) - } - }) -} diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 7c530ad..89e091f 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -4,17 +4,23 @@ import ( "bufio" "bytes" "context" + "crypto/rand" + "crypto/rsa" "flag" "fmt" "net" "net/http" + "net/http/httptest" "os" "path/filepath" "regexp" + "strconv" "strings" "testing" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/coreos/go-oidc/v3/oidc/oidctest" "golang.org/x/sync/errgroup" "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd/api" @@ -26,17 +32,22 @@ import ( ) type httpContext struct { - t *testing.T klogState klog.State - logBuffer bytes.Buffer - httpAddress string // HTTP server address + LogBuffer bytes.Buffer + HttpAddress string // HTTP server address timeoutCancel context.CancelFunc // Release resources if test completes before the timeout - stopServer context.CancelFunc - waitForShutdown func() error + StopServer context.CancelFunc + WaitForShutdown func() error + StaticConfig *config.StaticConfig + OidcProvider *oidc.Provider } -func (c *httpContext) beforeEach() { +func (c *httpContext) beforeEach(t *testing.T) { + t.Helper() http.DefaultClient.Timeout = 10 * time.Second + if c.StaticConfig == nil { + c.StaticConfig = &config.StaticConfig{} + } // Fake Kubernetes configuration fakeConfig := api.NewConfig() fakeConfig.Clusters["fake"] = api.NewCluster() @@ -44,7 +55,7 @@ func (c *httpContext) beforeEach() { fakeConfig.Contexts["fake-context"] = api.NewContext() fakeConfig.Contexts["fake-context"].Cluster = "fake" fakeConfig.CurrentContext = "fake-context" - kubeConfig := filepath.Join(c.t.TempDir(), "config") + kubeConfig := filepath.Join(t.TempDir(), "config") _ = clientcmd.WriteToFile(*fakeConfig, kubeConfig) _ = os.Setenv("KUBECONFIG", kubeConfig) // Capture logging @@ -52,33 +63,33 @@ func (c *httpContext) beforeEach() { flags := flag.NewFlagSet("test", flag.ContinueOnError) klog.InitFlags(flags) _ = flags.Set("v", "5") - klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.logBuffer)))) + klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.LogBuffer)))) // Start server in random port ln, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { - c.t.Fatalf("Failed to find random port for HTTP server: %v", err) + t.Fatalf("Failed to find random port for HTTP server: %v", err) } - c.httpAddress = ln.Addr().String() + c.HttpAddress = ln.Addr().String() if randomPortErr := ln.Close(); randomPortErr != nil { - c.t.Fatalf("Failed to close random port listener: %v", randomPortErr) + t.Fatalf("Failed to close random port listener: %v", randomPortErr) } - staticConfig := &config.StaticConfig{Port: fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)} + c.StaticConfig.Port = fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port) mcpServer, err := mcp.NewServer(mcp.Configuration{ Profile: mcp.Profiles[0], - StaticConfig: staticConfig, + StaticConfig: c.StaticConfig, }) if err != nil { - c.t.Fatalf("Failed to create MCP server: %v", err) + t.Fatalf("Failed to create MCP server: %v", err) } var timeoutCtx, cancelCtx context.Context - timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second) + timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second) group, gc := errgroup.WithContext(timeoutCtx) - cancelCtx, c.stopServer = context.WithCancel(gc) - group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig, nil) }) - c.waitForShutdown = group.Wait + cancelCtx, c.StopServer = context.WithCancel(gc) + group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) }) + c.WaitForShutdown = group.Wait // Wait for HTTP server to start (using net) for i := 0; i < 10; i++ { - conn, err := net.Dial("tcp", c.httpAddress) + conn, err := net.Dial("tcp", c.HttpAddress) if err == nil { _ = conn.Close() break @@ -87,11 +98,12 @@ func (c *httpContext) beforeEach() { } } -func (c *httpContext) afterEach() { - c.stopServer() - err := c.waitForShutdown() +func (c *httpContext) afterEach(t *testing.T) { + t.Helper() + c.StopServer() + err := c.WaitForShutdown() if err != nil { - c.t.Errorf("HTTP server did not shut down gracefully: %v", err) + t.Errorf("HTTP server did not shut down gracefully: %v", err) } c.timeoutCancel() c.klogState.Restore() @@ -99,34 +111,61 @@ func (c *httpContext) afterEach() { } func testCase(t *testing.T, test func(c *httpContext)) { - ctx := &httpContext{t: t} - ctx.beforeEach() - t.Cleanup(ctx.afterEach) - test(ctx) + testCaseWithContext(t, &httpContext{}, test) +} + +func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpContext)) { + httpCtx.beforeEach(t) + t.Cleanup(func() { httpCtx.afterEach(t) }) + test(httpCtx) +} + +func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key for oidc: %v", err) + } + oidcServer := &oidctest.Server{ + PublicKeys: []oidctest.PublicKey{ + { + PublicKey: privateKey.Public(), + KeyID: "test-oidc-key-id", + Algorithm: oidc.RS256, + }, + }, + } + httpServer = httptest.NewServer(oidcServer) + oidcServer.SetIssuer(httpServer.URL) + oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL) + if err != nil { + t.Fatalf("failed to create OIDC provider: %v", err) + } + return } func TestGracefulShutdown(t *testing.T) { testCase(t, func(ctx *httpContext) { - ctx.stopServer() - err := ctx.waitForShutdown() + ctx.StopServer() + err := ctx.WaitForShutdown() t.Run("Stops gracefully", func(t *testing.T) { if err != nil { t.Errorf("Expected graceful shutdown, but got error: %v", err) } }) t.Run("Stops on context cancel", func(t *testing.T) { - if !strings.Contains(ctx.logBuffer.String(), "Context cancelled, initiating graceful shutdown") { - t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.logBuffer.String()) + if !strings.Contains(ctx.LogBuffer.String(), "Context cancelled, initiating graceful shutdown") { + t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.LogBuffer.String()) } }) t.Run("Starts server shutdown", func(t *testing.T) { - if !strings.Contains(ctx.logBuffer.String(), "Shutting down HTTP server gracefully") { - t.Errorf("Expected graceful shutdown log, got: %s", ctx.logBuffer.String()) + if !strings.Contains(ctx.LogBuffer.String(), "Shutting down HTTP server gracefully") { + t.Errorf("Expected graceful shutdown log, got: %s", ctx.LogBuffer.String()) } }) t.Run("Server shutdown completes", func(t *testing.T) { - if !strings.Contains(ctx.logBuffer.String(), "HTTP server shutdown complete") { - t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.logBuffer.String()) + if !strings.Contains(ctx.LogBuffer.String(), "HTTP server shutdown complete") { + t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.LogBuffer.String()) } }) }) @@ -134,7 +173,7 @@ func TestGracefulShutdown(t *testing.T) { func TestSseTransport(t *testing.T) { testCase(t, func(ctx *httpContext) { - sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.httpAddress)) + sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.HttpAddress)) t.Cleanup(func() { _ = sseResp.Body.Close() }) t.Run("Exposes SSE endpoint at /sse", func(t *testing.T) { if sseErr != nil { @@ -167,7 +206,7 @@ func TestSseTransport(t *testing.T) { } }) messageResp, messageErr := http.Post( - fmt.Sprintf("http://%s/message?sessionId=%s", ctx.httpAddress, strings.TrimSpace(endpoint[25:])), + fmt.Sprintf("http://%s/message?sessionId=%s", ctx.HttpAddress, strings.TrimSpace(endpoint[25:])), "application/json", bytes.NewBufferString("{}"), ) @@ -185,7 +224,7 @@ func TestSseTransport(t *testing.T) { func TestStreamableHttpTransport(t *testing.T) { testCase(t, func(ctx *httpContext) { - mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.httpAddress)) + mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) t.Cleanup(func() { _ = mcpGetResp.Body.Close() }) t.Run("Exposes MCP GET endpoint at /mcp", func(t *testing.T) { if mcpGetErr != nil { @@ -200,7 +239,7 @@ func TestStreamableHttpTransport(t *testing.T) { t.Errorf("Expected Content-Type text/event-stream (GET), got %s", mcpGetResp.Header.Get("Content-Type")) } }) - mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.httpAddress), "application/json", bytes.NewBufferString("{}")) + mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), "application/json", bytes.NewBufferString("{}")) t.Cleanup(func() { _ = mcpPostResp.Body.Close() }) t.Run("Exposes MCP POST endpoint at /mcp", func(t *testing.T) { if mcpPostErr != nil { @@ -221,7 +260,7 @@ func TestStreamableHttpTransport(t *testing.T) { func TestHealthCheck(t *testing.T) { testCase(t, func(ctx *httpContext) { t.Run("Exposes health check endpoint at /healthz", func(t *testing.T) { - resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.httpAddress)) + resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress)) if err != nil { t.Fatalf("Failed to get health check endpoint: %v", err) } @@ -231,11 +270,24 @@ func TestHealthCheck(t *testing.T) { } }) }) + // Health exposed even when require Authorization + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress)) + if err != nil { + t.Fatalf("Failed to get health check endpoint with OAuth: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Health check with OAuth returns HTTP 200 OK", func(t *testing.T) { + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) + } + }) + }) } func TestWellKnownOAuthProtectedResource(t *testing.T) { testCase(t, func(ctx *httpContext) { - resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress)) + resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress)) t.Cleanup(func() { _ = resp.Body.Close() }) t.Run("Exposes .well-known/oauth-protected-resource endpoint", func(t *testing.T) { if err != nil { @@ -255,17 +307,17 @@ func TestWellKnownOAuthProtectedResource(t *testing.T) { func TestMiddlewareLogging(t *testing.T) { testCase(t, func(ctx *httpContext) { - _, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress)) + _, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress)) t.Run("Logs HTTP requests and responses", func(t *testing.T) { - if !strings.Contains(ctx.logBuffer.String(), "GET /.well-known/oauth-protected-resource 200") { - t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.logBuffer.String()) + if !strings.Contains(ctx.LogBuffer.String(), "GET /.well-known/oauth-protected-resource 200") { + t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.LogBuffer.String()) } }) t.Run("Logs HTTP request duration", func(t *testing.T) { expected := `"GET /.well-known/oauth-protected-resource 200 (.+)"` - m := regexp.MustCompile(expected).FindStringSubmatch(ctx.logBuffer.String()) + m := regexp.MustCompile(expected).FindStringSubmatch(ctx.LogBuffer.String()) if len(m) != 2 { - t.Fatalf("Expected log entry to contain duration, got %s", ctx.logBuffer.String()) + t.Fatalf("Expected log entry to contain duration, got %s", ctx.LogBuffer.String()) } duration, err := time.ParseDuration(m[1]) if err != nil { @@ -276,5 +328,188 @@ func TestMiddlewareLogging(t *testing.T) { } }) }) - +} + +func TestAuthorizationUnauthorized(t *testing.T) { + // Missing Authorization header + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with MISSING Authorization header returns 401 - Unauthorized", func(t *testing.T) { + if resp.StatusCode != 401 { + t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) + } + }) + t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with MISSING Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") { + t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String()) + } + }) + }) + // Authorization header without Bearer prefix + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with INCOMPATIBLE Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") { + t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String()) + } + }) + }) + // Invalid Authorization header + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer invalid_base64"+tokenBasicNotExpired) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with INVALID Authorization header returns 401 - Unauthorized", func(t *testing.T) { + if resp.StatusCode != 401 { + t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) + } + }) + t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") && + !strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") { + t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String()) + } + }) + }) + // Expired Authorization Bearer token + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+tokenBasicExpired) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with EXPIRED Authorization header returns 401 - Unauthorized", func(t *testing.T) { + if resp.StatusCode != 401 { + t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) + } + }) + t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") && + !strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") { + t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String()) + } + }) + }) + // Failed OIDC validation + key, oidcProvider, httpServer := NewOidcTestServer(t) + t.Cleanup(httpServer.Close) + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) { + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with INVALID OIDC Authorization header returns 401 - Unauthorized", func(t *testing.T) { + if resp.StatusCode != 401 { + t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) + } + }) + t.Run("Protected resource with INVALID OIDC Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") && + !strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") { + t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String()) + } + }) + }) + // Failed Kubernetes TokenReview + rawClaims := `{ + "iss": "` + httpServer.URL + `", + "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, + "aud": "kubernetes-mcp-server" + }` + validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims) + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) { + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+validOidcToken) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close }) + t.Run("Protected resource with INVALID KUBERNETES Authorization header returns 401 - Unauthorized", func(t *testing.T) { + if resp.StatusCode != 401 { + t.Errorf("Expected HTTP 401, got %d", resp.StatusCode) + } + }) + t.Run("Protected resource with INVALID KUBERNETES Authorization header returns WWW-Authenticate header", func(t *testing.T) { + authHeader := resp.Header.Get("WWW-Authenticate") + expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"` + if authHeader != expected { + t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader) + } + }) + t.Run("Protected resource with INVALID KUBERNETES Authorization header logs error", func(t *testing.T) { + if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - API Server token validation error") { + t.Errorf("Expected log entry for Kubernetes TokenReview error, got: %s", ctx.LogBuffer.String()) + } + }) + }) }