mirror of
https://github.com/openshift/openshift-mcp-server.git
synced 2025-10-17 14:27:48 +03:00
test(auth): complete test suite for unauthorized scenarios (#220)
Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
@@ -41,9 +41,9 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
|
||||
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="missing_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -103,7 +103,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
|
||||
// with the other token in the headers (TODO: still need to validate aud and exp of this token separately).
|
||||
_, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience)
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -220,103 +218,3 @@ func TestJWTClaimsGetScopes(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorizationMiddleware(t *testing.T) {
|
||||
// Create a mock handler
|
||||
handlerCalled := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("OAuth disabled - passes through", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth disabled
|
||||
middleware := AuthorizationMiddleware(false, "", nil, nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request without authorization header
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called when OAuth is disabled")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("healthz endpoint - passes through", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil, nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request to healthz endpoint
|
||||
req := httptest.NewRequest("GET", "/healthz", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called for healthz endpoint")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OAuth enabled - missing token", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil, nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request without authorization header
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("expected handler NOT to be called when token is missing")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "Bearer token required") {
|
||||
t.Errorf("expected bearer token error message, got %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OAuth enabled - invalid token format", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil, nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request with invalid bearer token
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("expected handler NOT to be called when token is invalid")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "Invalid token") {
|
||||
t.Errorf("expected invalid token error message, got %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,17 +4,23 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/coreos/go-oidc/v3/oidc/oidctest"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
"k8s.io/client-go/tools/clientcmd/api"
|
||||
@@ -26,17 +32,22 @@ import (
|
||||
)
|
||||
|
||||
type httpContext struct {
|
||||
t *testing.T
|
||||
klogState klog.State
|
||||
logBuffer bytes.Buffer
|
||||
httpAddress string // HTTP server address
|
||||
LogBuffer bytes.Buffer
|
||||
HttpAddress string // HTTP server address
|
||||
timeoutCancel context.CancelFunc // Release resources if test completes before the timeout
|
||||
stopServer context.CancelFunc
|
||||
waitForShutdown func() error
|
||||
StopServer context.CancelFunc
|
||||
WaitForShutdown func() error
|
||||
StaticConfig *config.StaticConfig
|
||||
OidcProvider *oidc.Provider
|
||||
}
|
||||
|
||||
func (c *httpContext) beforeEach() {
|
||||
func (c *httpContext) beforeEach(t *testing.T) {
|
||||
t.Helper()
|
||||
http.DefaultClient.Timeout = 10 * time.Second
|
||||
if c.StaticConfig == nil {
|
||||
c.StaticConfig = &config.StaticConfig{}
|
||||
}
|
||||
// Fake Kubernetes configuration
|
||||
fakeConfig := api.NewConfig()
|
||||
fakeConfig.Clusters["fake"] = api.NewCluster()
|
||||
@@ -44,7 +55,7 @@ func (c *httpContext) beforeEach() {
|
||||
fakeConfig.Contexts["fake-context"] = api.NewContext()
|
||||
fakeConfig.Contexts["fake-context"].Cluster = "fake"
|
||||
fakeConfig.CurrentContext = "fake-context"
|
||||
kubeConfig := filepath.Join(c.t.TempDir(), "config")
|
||||
kubeConfig := filepath.Join(t.TempDir(), "config")
|
||||
_ = clientcmd.WriteToFile(*fakeConfig, kubeConfig)
|
||||
_ = os.Setenv("KUBECONFIG", kubeConfig)
|
||||
// Capture logging
|
||||
@@ -52,33 +63,33 @@ func (c *httpContext) beforeEach() {
|
||||
flags := flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
klog.InitFlags(flags)
|
||||
_ = flags.Set("v", "5")
|
||||
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.logBuffer))))
|
||||
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(5), textlogger.Output(&c.LogBuffer))))
|
||||
// Start server in random port
|
||||
ln, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
c.t.Fatalf("Failed to find random port for HTTP server: %v", err)
|
||||
t.Fatalf("Failed to find random port for HTTP server: %v", err)
|
||||
}
|
||||
c.httpAddress = ln.Addr().String()
|
||||
c.HttpAddress = ln.Addr().String()
|
||||
if randomPortErr := ln.Close(); randomPortErr != nil {
|
||||
c.t.Fatalf("Failed to close random port listener: %v", randomPortErr)
|
||||
t.Fatalf("Failed to close random port listener: %v", randomPortErr)
|
||||
}
|
||||
staticConfig := &config.StaticConfig{Port: fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)}
|
||||
c.StaticConfig.Port = fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port)
|
||||
mcpServer, err := mcp.NewServer(mcp.Configuration{
|
||||
Profile: mcp.Profiles[0],
|
||||
StaticConfig: staticConfig,
|
||||
StaticConfig: c.StaticConfig,
|
||||
})
|
||||
if err != nil {
|
||||
c.t.Fatalf("Failed to create MCP server: %v", err)
|
||||
t.Fatalf("Failed to create MCP server: %v", err)
|
||||
}
|
||||
var timeoutCtx, cancelCtx context.Context
|
||||
timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second)
|
||||
timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second)
|
||||
group, gc := errgroup.WithContext(timeoutCtx)
|
||||
cancelCtx, c.stopServer = context.WithCancel(gc)
|
||||
group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig, nil) })
|
||||
c.waitForShutdown = group.Wait
|
||||
cancelCtx, c.StopServer = context.WithCancel(gc)
|
||||
group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider) })
|
||||
c.WaitForShutdown = group.Wait
|
||||
// Wait for HTTP server to start (using net)
|
||||
for i := 0; i < 10; i++ {
|
||||
conn, err := net.Dial("tcp", c.httpAddress)
|
||||
conn, err := net.Dial("tcp", c.HttpAddress)
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
break
|
||||
@@ -87,11 +98,12 @@ func (c *httpContext) beforeEach() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *httpContext) afterEach() {
|
||||
c.stopServer()
|
||||
err := c.waitForShutdown()
|
||||
func (c *httpContext) afterEach(t *testing.T) {
|
||||
t.Helper()
|
||||
c.StopServer()
|
||||
err := c.WaitForShutdown()
|
||||
if err != nil {
|
||||
c.t.Errorf("HTTP server did not shut down gracefully: %v", err)
|
||||
t.Errorf("HTTP server did not shut down gracefully: %v", err)
|
||||
}
|
||||
c.timeoutCancel()
|
||||
c.klogState.Restore()
|
||||
@@ -99,34 +111,61 @@ func (c *httpContext) afterEach() {
|
||||
}
|
||||
|
||||
func testCase(t *testing.T, test func(c *httpContext)) {
|
||||
ctx := &httpContext{t: t}
|
||||
ctx.beforeEach()
|
||||
t.Cleanup(ctx.afterEach)
|
||||
test(ctx)
|
||||
testCaseWithContext(t, &httpContext{}, test)
|
||||
}
|
||||
|
||||
func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpContext)) {
|
||||
httpCtx.beforeEach(t)
|
||||
t.Cleanup(func() { httpCtx.afterEach(t) })
|
||||
test(httpCtx)
|
||||
}
|
||||
|
||||
func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key for oidc: %v", err)
|
||||
}
|
||||
oidcServer := &oidctest.Server{
|
||||
PublicKeys: []oidctest.PublicKey{
|
||||
{
|
||||
PublicKey: privateKey.Public(),
|
||||
KeyID: "test-oidc-key-id",
|
||||
Algorithm: oidc.RS256,
|
||||
},
|
||||
},
|
||||
}
|
||||
httpServer = httptest.NewServer(oidcServer)
|
||||
oidcServer.SetIssuer(httpServer.URL)
|
||||
oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create OIDC provider: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
ctx.stopServer()
|
||||
err := ctx.waitForShutdown()
|
||||
ctx.StopServer()
|
||||
err := ctx.WaitForShutdown()
|
||||
t.Run("Stops gracefully", func(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("Expected graceful shutdown, but got error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("Stops on context cancel", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.logBuffer.String(), "Context cancelled, initiating graceful shutdown") {
|
||||
t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.logBuffer.String())
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Context cancelled, initiating graceful shutdown") {
|
||||
t.Errorf("Context cancelled, initiating graceful shutdown, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
t.Run("Starts server shutdown", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.logBuffer.String(), "Shutting down HTTP server gracefully") {
|
||||
t.Errorf("Expected graceful shutdown log, got: %s", ctx.logBuffer.String())
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Shutting down HTTP server gracefully") {
|
||||
t.Errorf("Expected graceful shutdown log, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
t.Run("Server shutdown completes", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.logBuffer.String(), "HTTP server shutdown complete") {
|
||||
t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.logBuffer.String())
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "HTTP server shutdown complete") {
|
||||
t.Errorf("Expected HTTP server shutdown completed log, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -134,7 +173,7 @@ func TestGracefulShutdown(t *testing.T) {
|
||||
|
||||
func TestSseTransport(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.httpAddress))
|
||||
sseResp, sseErr := http.Get(fmt.Sprintf("http://%s/sse", ctx.HttpAddress))
|
||||
t.Cleanup(func() { _ = sseResp.Body.Close() })
|
||||
t.Run("Exposes SSE endpoint at /sse", func(t *testing.T) {
|
||||
if sseErr != nil {
|
||||
@@ -167,7 +206,7 @@ func TestSseTransport(t *testing.T) {
|
||||
}
|
||||
})
|
||||
messageResp, messageErr := http.Post(
|
||||
fmt.Sprintf("http://%s/message?sessionId=%s", ctx.httpAddress, strings.TrimSpace(endpoint[25:])),
|
||||
fmt.Sprintf("http://%s/message?sessionId=%s", ctx.HttpAddress, strings.TrimSpace(endpoint[25:])),
|
||||
"application/json",
|
||||
bytes.NewBufferString("{}"),
|
||||
)
|
||||
@@ -185,7 +224,7 @@ func TestSseTransport(t *testing.T) {
|
||||
|
||||
func TestStreamableHttpTransport(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.httpAddress))
|
||||
mcpGetResp, mcpGetErr := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
|
||||
t.Cleanup(func() { _ = mcpGetResp.Body.Close() })
|
||||
t.Run("Exposes MCP GET endpoint at /mcp", func(t *testing.T) {
|
||||
if mcpGetErr != nil {
|
||||
@@ -200,7 +239,7 @@ func TestStreamableHttpTransport(t *testing.T) {
|
||||
t.Errorf("Expected Content-Type text/event-stream (GET), got %s", mcpGetResp.Header.Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.httpAddress), "application/json", bytes.NewBufferString("{}"))
|
||||
mcpPostResp, mcpPostErr := http.Post(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), "application/json", bytes.NewBufferString("{}"))
|
||||
t.Cleanup(func() { _ = mcpPostResp.Body.Close() })
|
||||
t.Run("Exposes MCP POST endpoint at /mcp", func(t *testing.T) {
|
||||
if mcpPostErr != nil {
|
||||
@@ -221,7 +260,7 @@ func TestStreamableHttpTransport(t *testing.T) {
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
t.Run("Exposes health check endpoint at /healthz", func(t *testing.T) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.httpAddress))
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get health check endpoint: %v", err)
|
||||
}
|
||||
@@ -231,11 +270,24 @@ func TestHealthCheck(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
// Health exposed even when require Authorization
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
t.Run("Health check with OAuth returns HTTP 200 OK", func(t *testing.T) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWellKnownOAuthProtectedResource(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress))
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress))
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
t.Run("Exposes .well-known/oauth-protected-resource endpoint", func(t *testing.T) {
|
||||
if err != nil {
|
||||
@@ -255,17 +307,17 @@ func TestWellKnownOAuthProtectedResource(t *testing.T) {
|
||||
|
||||
func TestMiddlewareLogging(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
_, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.httpAddress))
|
||||
_, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress))
|
||||
t.Run("Logs HTTP requests and responses", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.logBuffer.String(), "GET /.well-known/oauth-protected-resource 200") {
|
||||
t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.logBuffer.String())
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "GET /.well-known/oauth-protected-resource 200") {
|
||||
t.Errorf("Expected log entry for GET /.well-known/oauth-protected-resource, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
t.Run("Logs HTTP request duration", func(t *testing.T) {
|
||||
expected := `"GET /.well-known/oauth-protected-resource 200 (.+)"`
|
||||
m := regexp.MustCompile(expected).FindStringSubmatch(ctx.logBuffer.String())
|
||||
m := regexp.MustCompile(expected).FindStringSubmatch(ctx.LogBuffer.String())
|
||||
if len(m) != 2 {
|
||||
t.Fatalf("Expected log entry to contain duration, got %s", ctx.logBuffer.String())
|
||||
t.Fatalf("Expected log entry to contain duration, got %s", ctx.LogBuffer.String())
|
||||
}
|
||||
duration, err := time.ParseDuration(m[1])
|
||||
if err != nil {
|
||||
@@ -276,5 +328,188 @@ func TestMiddlewareLogging(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestAuthorizationUnauthorized(t *testing.T) {
|
||||
// Missing Authorization header
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with MISSING Authorization header returns 401 - Unauthorized", func(t *testing.T) {
|
||||
if resp.StatusCode != 401 {
|
||||
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with MISSING Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") {
|
||||
t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
// Authorization header without Bearer prefix
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="missing_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INCOMPATIBLE Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - missing or invalid bearer token") {
|
||||
t.Errorf("Expected log entry for missing or invalid bearer token, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
// Invalid Authorization header
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer invalid_base64"+tokenBasicNotExpired)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with INVALID Authorization header returns 401 - Unauthorized", func(t *testing.T) {
|
||||
if resp.StatusCode != 401 {
|
||||
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
|
||||
!strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") {
|
||||
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
// Expired Authorization Bearer token
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenBasicExpired)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with EXPIRED Authorization header returns 401 - Unauthorized", func(t *testing.T) {
|
||||
if resp.StatusCode != 401 {
|
||||
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
|
||||
!strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") {
|
||||
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
// Failed OIDC validation
|
||||
key, oidcProvider, httpServer := NewOidcTestServer(t)
|
||||
t.Cleanup(httpServer.Close)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with INVALID OIDC Authorization header returns 401 - Unauthorized", func(t *testing.T) {
|
||||
if resp.StatusCode != 401 {
|
||||
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID OIDC Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") &&
|
||||
!strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") {
|
||||
t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
// Failed Kubernetes TokenReview
|
||||
rawClaims := `{
|
||||
"iss": "` + httpServer.URL + `",
|
||||
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
|
||||
"aud": "kubernetes-mcp-server"
|
||||
}`
|
||||
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+validOidcToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close })
|
||||
t.Run("Protected resource with INVALID KUBERNETES Authorization header returns 401 - Unauthorized", func(t *testing.T) {
|
||||
if resp.StatusCode != 401 {
|
||||
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID KUBERNETES Authorization header returns WWW-Authenticate header", func(t *testing.T) {
|
||||
authHeader := resp.Header.Get("WWW-Authenticate")
|
||||
expected := `Bearer realm="Kubernetes MCP Server", audience="kubernetes-mcp-server", error="invalid_token"`
|
||||
if authHeader != expected {
|
||||
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
|
||||
}
|
||||
})
|
||||
t.Run("Protected resource with INVALID KUBERNETES Authorization header logs error", func(t *testing.T) {
|
||||
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - API Server token validation error") {
|
||||
t.Errorf("Expected log entry for Kubernetes TokenReview error, got: %s", ctx.LogBuffer.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user