diff --git a/go.mod b/go.mod index 8332ea4..1087161 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/spf13/afero v1.14.0 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.7 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 helm.sh/helm/v3 v3.18.4 k8s.io/api v0.33.3 @@ -116,7 +117,6 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.40.0 // indirect golang.org/x/net v0.42.0 // indirect - golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.34.0 // indirect golang.org/x/term v0.33.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/internal/test/mock_server.go b/internal/test/mock_server.go index 4426991..b5f9047 100644 --- a/internal/test/mock_server.go +++ b/internal/test/mock_server.go @@ -14,6 +14,7 @@ import ( "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream/spdy" "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd/api" ) type MockServer struct { @@ -56,6 +57,21 @@ func (m *MockServer) Config() *rest.Config { return m.config } +func (m *MockServer) KubeConfig() *api.Config { + fakeConfig := api.NewConfig() + fakeConfig.Clusters["fake"] = api.NewCluster() + fakeConfig.Clusters["fake"].Server = m.config.Host + fakeConfig.Clusters["fake"].CertificateAuthorityData = m.config.CAData + fakeConfig.AuthInfos["fake"] = api.NewAuthInfo() + fakeConfig.AuthInfos["fake"].ClientKeyData = m.config.KeyData + fakeConfig.AuthInfos["fake"].ClientCertificateData = m.config.CertData + fakeConfig.Contexts["fake-context"] = api.NewContext() + fakeConfig.Contexts["fake-context"].Cluster = "fake" + fakeConfig.Contexts["fake-context"].AuthInfo = "fake" + fakeConfig.CurrentContext = "fake-context" + return fakeConfig +} + func WriteObject(w http.ResponseWriter, obj runtime.Object) { w.Header().Set("Content-Type", runtime.ContentTypeJSON) if err := json.NewEncoder(w).Encode(obj); err != nil { diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 18d2659..8085d69 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -9,6 +9,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + authenticationapiv1 "k8s.io/api/authentication/v1" "k8s.io/klog/v2" "k8s.io/utils/strings/slices" @@ -19,8 +20,37 @@ const ( Audience = "mcp-server" ) -// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API -func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *oidc.Provider, mcpServer *mcp.Server) func(http.Handler) http.Handler { +type KubernetesApiTokenVerifier interface { + // KubernetesApiVerifyToken TODO: clarify proper implementation + KubernetesApiVerifyToken(ctx context.Context, token, audience string) (*authenticationapiv1.UserInfo, []string, error) +} + +// AuthorizationMiddleware validates the OAuth flow for protected resources. +// +// The flow is skipped for unprotected resources, such as health checks and well-known endpoints. +// +// There are several auth scenarios +// +// 1. requireOAuth is false: +// +// - The OAuth flow is skipped, and the server is effectively unprotected. +// - The request is passed to the next handler without any validation. +// +// see TestAuthorizationRequireOAuthFalse +// +// 2. requireOAuth is set to true, server is protected: +// +// 2.1. Raw Token Validation (oidcProvider is nil): +// - The token is validated offline for basic sanity checks (audience and expiration). +// - The token is then used against the Kubernetes API Server for TokenReview. +// +// 2.2. OIDC Provider Validation (oidcProvider is not nil): +// - The token is validated offline for basic sanity checks (audience and expiration). +// - The token is then validated against the OIDC Provider. +// - The token is then used against the Kubernetes API Server for TokenReview. +// +// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx): +func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) { @@ -38,20 +68,13 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) - if serverURL == "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience)) - } else { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="missing_token"`, audience, serverURL, oauthProtectedResourceEndpoint)) - } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience)) http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized) return } token := strings.TrimPrefix(authHeader, "Bearer ") - // Validate the token offline for simple sanity check - // Because missing expected audience and expired tokens must be - // rejected already. claims, err := ParseJWTClaims(token) if err == nil && claims != nil { err = claims.Validate(r.Context(), audience, oidcProvider) @@ -59,11 +82,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * if err != nil { klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) - if serverURL == "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) - } else { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint)) - } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) return } @@ -85,15 +104,11 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * // 2. b. If this is not the only token in the headers, the token in here is used // only for authentication and authorization. Therefore, we need to send TokenReview request // with the other token in the headers (TODO: still need to validate aud and exp of this token separately). - _, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience) + _, _, err = verifier.KubernetesApiVerifyToken(r.Context(), token, audience) if err != nil { klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) - if serverURL == "" { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) - } else { - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint)) - } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience)) http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) return } diff --git a/pkg/http/http.go b/pkg/http/http.go index 113032b..7d45bd2 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -28,7 +28,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat mux := http.NewServeMux() wrappedMux := RequestMiddleware( - AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, oidcProvider, mcpServer)(mux), + AuthorizationMiddleware(staticConfig.RequireOAuth, oidcProvider, mcpServer)(mux), ) httpServer := &http.Server{ diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 82c62a8..4801874 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -19,11 +19,11 @@ import ( "testing" "time" + "github.com/containers/kubernetes-mcp-server/internal/test" "github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc/oidctest" "golang.org/x/sync/errgroup" "k8s.io/client-go/tools/clientcmd" - "k8s.io/client-go/tools/clientcmd/api" "k8s.io/klog/v2" "k8s.io/klog/v2/textlogger" @@ -33,6 +33,7 @@ import ( type httpContext struct { klogState klog.State + mockServer *test.MockServer LogBuffer bytes.Buffer HttpAddress string // HTTP server address timeoutCancel context.CancelFunc // Release resources if test completes before the timeout @@ -42,21 +43,31 @@ type httpContext struct { OidcProvider *oidc.Provider } +const tokenReviewSuccessful = ` + { + "kind": "TokenReview", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"token": "valid-token"}, + "status": { + "authenticated": true, + "user": { + "username": "test-user", + "groups": ["system:authenticated"] + } + } + }` + func (c *httpContext) beforeEach(t *testing.T) { t.Helper() http.DefaultClient.Timeout = 10 * time.Second if c.StaticConfig == nil { c.StaticConfig = &config.StaticConfig{} } + c.mockServer = test.NewMockServer() // Fake Kubernetes configuration - fakeConfig := api.NewConfig() - fakeConfig.Clusters["fake"] = api.NewCluster() - fakeConfig.Clusters["fake"].Server = "https://example.com" - fakeConfig.Contexts["fake-context"] = api.NewContext() - fakeConfig.Contexts["fake-context"].Cluster = "fake" - fakeConfig.CurrentContext = "fake-context" + mockKubeConfig := c.mockServer.KubeConfig() kubeConfig := filepath.Join(t.TempDir(), "config") - _ = clientcmd.WriteToFile(*fakeConfig, kubeConfig) + _ = clientcmd.WriteToFile(*mockKubeConfig, kubeConfig) _ = os.Setenv("KUBECONFIG", kubeConfig) // Capture logging c.klogState = klog.CaptureState() @@ -100,6 +111,7 @@ func (c *httpContext) beforeEach(t *testing.T) { func (c *httpContext) afterEach(t *testing.T) { t.Helper() + c.mockServer.Close() c.StopServer() err := c.WaitForShutdown() if err != nil { @@ -546,3 +558,81 @@ func TestAuthorizationUnauthorized(t *testing.T) { }) }) } + +// TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required. +func TestAuthorizationRequireOAuthFalse(t *testing.T) { + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) { + resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress)) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Protected resource with MISSING Authorization header returns 200 - OK)", func(t *testing.T) { + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) + } + }) + }) +} + +func TestAuthorizationRawToken(t *testing.T) { + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) { + ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tokenReviewSuccessful)) + return + } + })) + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Protected resource with VALID Authorization header returns 200 - OK", func(t *testing.T) { + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) + } + }) + }) +} + +func TestAuthorizationOidcToken(t *testing.T) { + key, oidcProvider, httpServer := NewOidcTestServer(t) + t.Cleanup(httpServer.Close) + rawClaims := `{ + "iss": "` + httpServer.URL + `", + "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `, + "aud": "mcp-server" + }` + validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims) + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) { + ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tokenReviewSuccessful)) + return + } + })) + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+validOidcToken) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get protected endpoint: %v", err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Protected resource with VALID OIDC Authorization header returns 200 - OK", func(t *testing.T) { + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) + } + }) + }) +} diff --git a/pkg/kubernetes/token.go b/pkg/kubernetes/token.go index bac697c..d81f413 100644 --- a/pkg/kubernetes/token.go +++ b/pkg/kubernetes/token.go @@ -3,6 +3,7 @@ package kubernetes import ( "context" "fmt" + authenticationv1api "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index cc34236..729cab2 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -122,9 +122,9 @@ func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer return server.NewStreamableHTTPServer(s.server, options...) } -// VerifyTokenAPIServer verifies the given token with the audience by +// KubernetesApiVerifyToken verifies the given token with the audience by // sending an TokenReview request to API Server. -func (s *Server) VerifyTokenAPIServer(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) { +func (s *Server) KubernetesApiVerifyToken(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) { if s.k == nil { return nil, nil, fmt.Errorf("kubernetes manager is not initialized") }