feat(auth): introduce require-oauth flag to comply with OAuth in MCP specification (170)

Introduce require-oauth flag

When this flag is enabled, authorization middleware will be turned on.
When this flag is enabled, Derived which is generated based on the client
token will not be used.
---
Wire Authorization middleware to http mux

This commit adds authorization middleware. Additionally, this commit
rejects the requests if the bearer token is absent in Authorization
header of the request.
---
Add offline token validation for expiration and audience

Per Model Context Protocol specification, MCP Servers must check the
audience field of the token to ensure that they are generated specifically
for them.

This commits parses the JWT token and asserts that audience is correct
and token is not expired.
---
Add online token verification via TokenReview request to API Server

This commit sends online token verification by sending request to
TokenReview endpoint of API Server with the token and expected audience.

If API Server returns the status as authenticated, that means this token
can be used to generate a new ad hoc token for MCP Server.

If API Server returns the status as not authenticated, that means this token
is invalid and MCP Server returns 401 to force the client to initiate OAuth flow.
---
Serve oauth protected resource metadata endpoint
---
Introduce server-url to be represented in protected resource metadata
---
Add error return type in Derived function
---
Return error if error occurs in Derived, when require-oauth
---
Add test cases for authorization-url and server-url
---
Wire server-url to audience, if it is set
---
Remove redundant ssebaseurl parameter from http
This commit is contained in:
Arda Güçlü
2025-07-14 07:31:17 +03:00
committed by GitHub
parent 114726fb7c
commit 275b91a00d
17 changed files with 827 additions and 34 deletions

View File

@@ -22,6 +22,9 @@ type StaticConfig struct {
DisableDestructive bool `toml:"disable_destructive,omitempty"`
EnabledTools []string `toml:"enabled_tools,omitempty"`
DisabledTools []string `toml:"disabled_tools,omitempty"`
RequireOAuth bool `toml:"require_oauth,omitempty"`
AuthorizationURL string `toml:"authorization_url,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
}
type GroupVersionKind struct {

120
pkg/http/authorization.go Normal file
View File

@@ -0,0 +1,120 @@
package http
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"slices"
"strings"
"time"
"k8s.io/klog/v2"
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
)
const (
Audience = "kubernetes-mcp-server"
)
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" {
next.ServeHTTP(w, r)
return
}
if !requireOAuth {
next.ServeHTTP(w, r)
return
}
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
audience := Audience
if serverURL != "" {
audience = serverURL
}
err := validateJWTToken(token, audience)
if err != nil {
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
// Validate token using Kubernetes TokenReview API
_, _, err = mcpServer.VerifyToken(r.Context(), token, Audience)
if err != nil {
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
}
type JWTClaims struct {
Issuer string `json:"iss"`
Audience []string `json:"aud"`
ExpiresAt int64 `json:"exp"`
}
// validateJWTToken validates basic JWT claims without signature verification
func validateJWTToken(token, audience string) error {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid JWT token format")
}
claims, err := parseJWTClaims(parts[1])
if err != nil {
return fmt.Errorf("failed to parse JWT claims: %v", err)
}
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
return fmt.Errorf("token expired")
}
if !slices.Contains(claims.Audience, audience) {
return fmt.Errorf("token audience mismatch: %v", claims.Audience)
}
return nil
}
func parseJWTClaims(payload string) (*JWTClaims, error) {
// Add padding if needed
if len(payload)%4 != 0 {
payload += strings.Repeat("=", 4-len(payload)%4)
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
}
return &claims, nil
}

View File

@@ -0,0 +1,272 @@
package http
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestParseJWTClaims(t *testing.T) {
t.Run("valid JWT payload", func(t *testing.T) {
// Sample payload from a valid JWT
payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9"
claims, err := parseJWTClaims(payload)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if claims == nil {
t.Fatal("expected claims, got nil")
}
if claims.Issuer != "https://kubernetes.default.svc.cluster.local" {
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer)
}
expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"}
if len(claims.Audience) != 2 {
t.Errorf("expected 2 audiences, got %d", len(claims.Audience))
}
for i, expected := range expectedAudiences {
if i >= len(claims.Audience) || claims.Audience[i] != expected {
t.Errorf("expected audience[%d] to be %s, got %s", i, expected, claims.Audience[i])
}
}
if claims.ExpiresAt != 1751963948 {
t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt)
}
})
t.Run("payload needs padding", func(t *testing.T) {
// Create a payload that needs padding
testClaims := JWTClaims{
Issuer: "test-issuer",
Audience: []string{"test-audience"},
ExpiresAt: time.Now().Add(time.Hour).Unix(),
}
jsonBytes, _ := json.Marshal(testClaims)
// Create a payload without proper padding
encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=")
claims, err := parseJWTClaims(encodedWithoutPadding)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if claims.Issuer != "test-issuer" {
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
}
})
t.Run("invalid base64 payload", func(t *testing.T) {
invalidPayload := "invalid-base64!!!"
_, err := parseJWTClaims(invalidPayload)
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
if !strings.Contains(err.Error(), "failed to decode JWT payload") {
t.Errorf("expected decode error message, got %v", err)
}
})
t.Run("invalid JSON payload", func(t *testing.T) {
// Valid base64 but invalid JSON
invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json"))
_, err := parseJWTClaims(invalidJSON)
if err == nil {
t.Error("expected error for invalid JSON, got nil")
}
if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") {
t.Errorf("expected unmarshal error message, got %v", err)
}
})
}
func TestValidateJWTToken(t *testing.T) {
t.Run("invalid token format - not enough parts", func(t *testing.T) {
invalidToken := "header.payload"
err := validateJWTToken(invalidToken, "test")
if err == nil {
t.Error("expected error for invalid token format, got nil")
}
if !strings.Contains(err.Error(), "invalid JWT token format") {
t.Errorf("expected format error message, got %v", err)
}
})
t.Run("expired token", func(t *testing.T) {
// Create an expired token
expiredClaims := JWTClaims{
Issuer: "test-issuer",
Audience: []string{"kubernetes-mcp-server"},
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // 1 hour ago
}
jsonBytes, _ := json.Marshal(expiredClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
expiredToken := "header." + payload + ".signature"
err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
if err == nil {
t.Error("expected error for expired token, got nil")
}
if !strings.Contains(err.Error(), "token expired") {
t.Errorf("expected expiration error message, got %v", err)
}
})
t.Run("multiple audiences with correct one", func(t *testing.T) {
// Create a token with multiple audiences including the correct one
multiAudClaims := JWTClaims{
Issuer: "test-issuer",
Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"},
ExpiresAt: time.Now().Add(time.Hour).Unix(),
}
jsonBytes, _ := json.Marshal(multiAudClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
multiAudToken := "header." + payload + ".signature"
err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
if err != nil {
t.Errorf("expected no error for token with multiple audiences, got %v", err)
}
})
t.Run("audience mismatch", func(t *testing.T) {
// Create a token with wrong audience
wrongAudClaims := JWTClaims{
Issuer: "test-issuer",
Audience: []string{"wrong-audience"},
ExpiresAt: time.Now().Add(time.Hour).Unix(),
}
jsonBytes, _ := json.Marshal(wrongAudClaims)
payload := base64.URLEncoding.EncodeToString(jsonBytes)
wrongAudToken := "header." + payload + ".signature"
err := validateJWTToken(wrongAudToken, "audience")
if err == nil {
t.Error("expected error for token with wrong audience, got nil")
}
if !strings.Contains(err.Error(), "audience mismatch") {
t.Errorf("expected audience mismatch error, got %v", err)
}
})
}
func TestAuthorizationMiddleware(t *testing.T) {
// Create a mock handler
handlerCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
})
t.Run("OAuth disabled - passes through", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth disabled
middleware := AuthorizationMiddleware(false, "", nil)
wrappedHandler := middleware(handler)
// Create request without authorization header
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if !handlerCalled {
t.Error("expected handler to be called when OAuth is disabled")
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
})
t.Run("healthz endpoint - passes through", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil)
wrappedHandler := middleware(handler)
// Create request to healthz endpoint
req := httptest.NewRequest("GET", "/healthz", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if !handlerCalled {
t.Error("expected handler to be called for healthz endpoint")
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
})
t.Run("OAuth enabled - missing token", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil)
wrappedHandler := middleware(handler)
// Create request without authorization header
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if handlerCalled {
t.Error("expected handler NOT to be called when token is missing")
}
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "Bearer token required") {
t.Errorf("expected bearer token error message, got %s", w.Body.String())
}
})
t.Run("OAuth enabled - invalid token format", func(t *testing.T) {
handlerCalled = false
// Create middleware with OAuth enabled
middleware := AuthorizationMiddleware(true, "", nil)
wrappedHandler := middleware(handler)
// Create request with invalid bearer token
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer invalid-token")
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
if handlerCalled {
t.Error("expected handler NOT to be called when token is invalid")
}
if w.Code != http.StatusUnauthorized {
t.Errorf("expected status 401, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "Invalid token") {
t.Errorf("expected invalid token error message, got %s", w.Body.String())
}
})
}

View File

@@ -2,6 +2,7 @@ package http
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
@@ -11,19 +12,23 @@ import (
"k8s.io/klog/v2"
"github.com/manusa/kubernetes-mcp-server/pkg/config"
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
)
func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string) error {
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig) error {
mux := http.NewServeMux()
wrappedMux := RequestMiddleware(mux)
wrappedMux := RequestMiddleware(
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, mcpServer)(mux),
)
httpServer := &http.Server{
Addr: ":" + port,
Addr: ":" + staticConfig.Port,
Handler: wrappedMux,
}
sseServer := mcpServer.ServeSse(sseBaseUrl, httpServer)
sseServer := mcpServer.ServeSse(staticConfig.SSEBaseURL, httpServer)
streamableHttpServer := mcpServer.ServeHTTP(httpServer)
mux.Handle("/sse", sseServer)
mux.Handle("/message", sseServer)
@@ -31,6 +36,34 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string)
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var authServers []string
if staticConfig.AuthorizationURL != "" {
authServers = []string{staticConfig.AuthorizationURL}
} else {
// Fallback to Kubernetes API server host if authorization_server is not configured
if apiServerHost := mcpServer.GetKubernetesAPIServerHost(); apiServerHost != "" {
authServers = []string{apiServerHost}
}
}
response := map[string]interface{}{
"authorization_servers": authServers,
"scopes_supported": []string{},
"bearer_methods_supported": []string{"header"},
}
if staticConfig.ServerURL != "" {
response["resource"] = staticConfig.ServerURL
}
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -40,7 +73,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string)
serverErr := make(chan error, 1)
go func() {
klog.V(0).Infof("Streaming and SSE HTTP servers starting on port %s and paths /mcp, /sse, /message", port)
klog.V(0).Infof("Streaming and SSE HTTP servers starting on port %s and paths /mcp, /sse, /message", staticConfig.Port)
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
serverErr <- err
}

View File

@@ -11,6 +11,11 @@ import (
func RequestMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/healthz" {
next.ServeHTTP(w, r)
return
}
start := time.Now()
lrw := &loggingResponseWriter{

View File

@@ -55,6 +55,9 @@ type MCPServerOptions struct {
ListOutput string
ReadOnly bool
DisableDestructive bool
RequireOAuth bool
AuthorizationURL string
ServerURL string
ConfigPath string
StaticConfig *config.StaticConfig
@@ -107,7 +110,12 @@ func NewMCPServer(streams genericiooptions.IOStreams) *cobra.Command {
cmd.Flags().StringVar(&o.ListOutput, "list-output", o.ListOutput, "Output format for resource list operations (one of: "+strings.Join(output.Names, ", ")+"). Defaults to table.")
cmd.Flags().BoolVar(&o.ReadOnly, "read-only", o.ReadOnly, "If true, only tools annotated with readOnlyHint=true are exposed")
cmd.Flags().BoolVar(&o.DisableDestructive, "disable-destructive", o.DisableDestructive, "If true, tools annotated with destructiveHint=true are disabled")
cmd.Flags().BoolVar(&o.RequireOAuth, "require-oauth", o.RequireOAuth, "If true, requires OAuth authorization as defined in the Model Context Protocol (MCP) specification. This flag is ignored if transport type is stdio")
cmd.Flags().MarkHidden("require-oauth")
cmd.Flags().StringVar(&o.AuthorizationURL, "authorization-url", o.AuthorizationURL, "OAuth authorization server URL for protected resource endpoint. If not provided, the Kubernetes API server host will be used. Only valid if require-oauth is enabled.")
cmd.Flags().MarkHidden("authorization-url")
cmd.Flags().StringVar(&o.ServerURL, "server-url", o.ServerURL, "Server URL of this application. Optional. If set, this url will be served in protected resource metadata endpoint and tokens will be validated with this audience. If not set, expected audience is kubernetes-mcp-server. Only valid if require-oauth is enabled.")
cmd.Flags().MarkHidden("server-url")
return cmd
}
@@ -124,6 +132,11 @@ func (m *MCPServerOptions) Complete(cmd *cobra.Command) error {
m.initializeLogging()
if m.StaticConfig.RequireOAuth && m.StaticConfig.Port == "" {
// RequireOAuth is not relevant flow for STDIO transport
m.StaticConfig.RequireOAuth = false
}
return nil
}
@@ -153,6 +166,15 @@ func (m *MCPServerOptions) loadFlags(cmd *cobra.Command) {
if cmd.Flag("disable-destructive").Changed {
m.StaticConfig.DisableDestructive = m.DisableDestructive
}
if cmd.Flag("require-oauth").Changed {
m.StaticConfig.RequireOAuth = m.RequireOAuth
}
if cmd.Flag("authorization-url").Changed {
m.StaticConfig.AuthorizationURL = m.AuthorizationURL
}
if cmd.Flag("server-url").Changed {
m.StaticConfig.ServerURL = m.ServerURL
}
}
func (m *MCPServerOptions) initializeLogging() {
@@ -171,6 +193,25 @@ func (m *MCPServerOptions) Validate() error {
if m.Port != "" && (m.SSEPort > 0 || m.HttpPort > 0) {
return fmt.Errorf("--port is mutually exclusive with deprecated --http-port and --sse-port flags")
}
if !m.StaticConfig.RequireOAuth && (m.StaticConfig.AuthorizationURL != "" || m.StaticConfig.ServerURL != "") {
return fmt.Errorf("authorization-url and server-url are only valid if require-oauth is enabled")
}
if m.StaticConfig.AuthorizationURL != "" &&
!strings.HasPrefix(m.StaticConfig.AuthorizationURL, "https://") {
if strings.HasPrefix(m.StaticConfig.AuthorizationURL, "http://") {
klog.Warningf("authorization-url is using http://, this is not recommended production use")
} else {
return fmt.Errorf("authorization-url must start with https://")
}
}
if m.StaticConfig.ServerURL != "" &&
!strings.HasPrefix(m.StaticConfig.ServerURL, "https://") {
if strings.HasPrefix(m.StaticConfig.ServerURL, "http://") {
klog.Warningf("server-url is using http://, this is not recommended production use")
} else {
return fmt.Errorf("server-url must start with https://")
}
}
return nil
}
@@ -206,7 +247,7 @@ func (m *MCPServerOptions) Run() error {
if m.StaticConfig.Port != "" {
ctx := context.Background()
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig.Port, m.SSEBaseUrl)
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig)
}
if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {

View File

@@ -230,3 +230,53 @@ func TestDisableDestructive(t *testing.T) {
}
})
}
func TestAuthorizationURL(t *testing.T) {
t.Run("invalid authorization-url without protocol", func(t *testing.T) {
ioStreams, _ := testStream()
rootCmd := NewMCPServer(ioStreams)
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--authorization-url", "example.com/auth", "--server-url", "https://example.com:8080"})
err := rootCmd.Execute()
if err == nil {
t.Fatal("Expected error for invalid authorization-url without protocol, got nil")
}
expected := "authorization-url must start with https://"
if !strings.Contains(err.Error(), expected) {
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
}
})
t.Run("valid authorization-url with https", func(t *testing.T) {
ioStreams, _ := testStream()
rootCmd := NewMCPServer(ioStreams)
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--authorization-url", "https://example.com/auth", "--server-url", "https://example.com:8080"})
err := rootCmd.Execute()
if err != nil {
t.Fatalf("Expected no error for valid https authorization-url, got %s", err.Error())
}
})
}
func TestServerURL(t *testing.T) {
t.Run("invalid server-url without protocol", func(t *testing.T) {
ioStreams, _ := testStream()
rootCmd := NewMCPServer(ioStreams)
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--server-url", "example.com:8080", "--authorization-url", "https://example.com/auth"})
err := rootCmd.Execute()
if err == nil {
t.Fatal("Expected error for invalid server-url without protocol, got nil")
}
expected := "server-url must start with https://"
if !strings.Contains(err.Error(), expected) {
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
}
})
t.Run("valid server-url with https", func(t *testing.T) {
ioStreams, _ := testStream()
rootCmd := NewMCPServer(ioStreams)
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--server-url", "https://example.com:8080", "--authorization-url", "https://example.com/auth"})
err := rootCmd.Execute()
if err != nil {
t.Fatalf("Expected no error for valid https server-url, got %s", err.Error())
}
})
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
authenticationv1api "k8s.io/api/authentication/v1"
authorizationv1api "k8s.io/api/authorization/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -11,6 +12,7 @@ import (
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/client-go/discovery"
"k8s.io/client-go/kubernetes"
authenticationv1 "k8s.io/client-go/kubernetes/typed/authentication/v1"
authorizationv1 "k8s.io/client-go/kubernetes/typed/authorization/v1"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/rest"
@@ -111,6 +113,15 @@ func (a *AccessControlClientset) SelfSubjectAccessReviews() (authorizationv1.Sel
return a.delegate.AuthorizationV1().SelfSubjectAccessReviews(), nil
}
// TokenReview returns TokenReviewInterface
func (a *AccessControlClientset) TokenReview() (authenticationv1.TokenReviewInterface, error) {
gvk := &schema.GroupVersionKind{Group: authenticationv1api.GroupName, Version: authorizationv1api.SchemeGroupVersion.Version, Kind: "TokenReview"}
if !isAllowed(a.staticConfig, gvk) {
return nil, isNotAllowedError(gvk)
}
return a.delegate.AuthenticationV1().TokenReviews(), nil
}
func NewAccessControlClientset(cfg *rest.Config, staticConfig *config.StaticConfig) (*AccessControlClientset, error) {
clientSet, err := kubernetes.NewForConfig(cfg)
if err != nil {

View File

@@ -2,6 +2,7 @@ package kubernetes
import (
"context"
"errors"
"strings"
"k8s.io/apimachinery/pkg/runtime"
@@ -125,6 +126,13 @@ func (m *Manager) Close() {
}
}
func (m *Manager) GetAPIServerHost() string {
if m.cfg == nil {
return ""
}
return m.cfg.Host
}
func (m *Manager) ToDiscoveryClient() (discovery.CachedDiscoveryInterface, error) {
return m.discoveryClient, nil
}
@@ -133,10 +141,13 @@ func (m *Manager) ToRESTMapper() (meta.RESTMapper, error) {
return m.accessControlRESTMapper, nil
}
func (m *Manager) Derived(ctx context.Context) *Kubernetes {
func (m *Manager) Derived(ctx context.Context) (*Kubernetes, error) {
authorization, ok := ctx.Value(OAuthAuthorizationHeader).(string)
if !ok || !strings.HasPrefix(authorization, "Bearer ") {
return &Kubernetes{manager: m}
if m.staticConfig.RequireOAuth {
return nil, errors.New("oauth token required")
}
return &Kubernetes{manager: m}, nil
}
klog.V(5).Infof("%s header found (Bearer), using provided bearer token", OAuthAuthorizationHeader)
derivedCfg := &rest.Config{
@@ -159,7 +170,11 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
}
clientCmdApiConfig, err := m.clientCmdConfig.RawConfig()
if err != nil {
return &Kubernetes{manager: m}
if m.staticConfig.RequireOAuth {
klog.Errorf("failed to get kubeconfig: %v", err)
return nil, errors.New("failed to get kubeconfig")
}
return &Kubernetes{manager: m}, nil
}
clientCmdApiConfig.AuthInfos = make(map[string]*clientcmdapi.AuthInfo)
derived := &Kubernetes{manager: &Manager{
@@ -169,7 +184,11 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
}}
derived.manager.accessControlClientSet, err = NewAccessControlClientset(derived.manager.cfg, derived.manager.staticConfig)
if err != nil {
return &Kubernetes{manager: m}
if m.staticConfig.RequireOAuth {
klog.Errorf("failed to get kubeconfig: %v", err)
return nil, errors.New("failed to get kubeconfig")
}
return &Kubernetes{manager: m}, nil
}
derived.manager.discoveryClient = memory.NewMemCacheClient(derived.manager.accessControlClientSet.DiscoveryClient())
derived.manager.accessControlRESTMapper = NewAccessControlRESTMapper(
@@ -178,9 +197,13 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
)
derived.manager.dynamicClient, err = dynamic.NewForConfig(derived.manager.cfg)
if err != nil {
return &Kubernetes{manager: m}
if m.staticConfig.RequireOAuth {
klog.Errorf("failed to initialize dynamic client: %v", err)
return nil, errors.New("failed to initialize dynamic client")
}
return &Kubernetes{manager: m}, nil
}
return derived
return derived, nil
}
func (k *Kubernetes) NewHelm() *helm.Helm {

View File

@@ -51,7 +51,10 @@ users:
}
defer testManager.Close()
ctx := context.Background()
derived := testManager.Derived(ctx)
derived, err := testManager.Derived(ctx)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
if derived.manager != testManager {
t.Errorf("expected original manager, got different manager")
@@ -73,7 +76,10 @@ users:
}
defer testManager.Close()
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "invalid-token")
derived := testManager.Derived(ctx)
derived, err := testManager.Derived(ctx)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
if derived.manager != testManager {
t.Errorf("expected original manager, got different manager")
@@ -96,7 +102,10 @@ users:
defer testManager.Close()
testBearerToken := "test-bearer-token-123"
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "Bearer "+testBearerToken)
derived := testManager.Derived(ctx)
derived, err := testManager.Derived(ctx)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
if derived.manager == testManager {
t.Errorf("expected new derived manager, got original manager")
@@ -208,4 +217,100 @@ users:
t.Error("expected dynamicClient to be initialized")
}
})
t.Run("with RequireOAuth=true and no authorization header returns oauth token required error", func(t *testing.T) {
testStaticConfig := &config.StaticConfig{
KubeConfig: kubeconfigPath,
RequireOAuth: true,
DisabledTools: []string{"configuration_view"},
DeniedResources: []config.GroupVersionKind{
{Group: "apps", Version: "v1", Kind: "Deployment"},
},
}
testManager, err := NewManager(testStaticConfig)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
defer testManager.Close()
ctx := context.Background()
derived, err := testManager.Derived(ctx)
if err == nil {
t.Fatal("expected error for missing oauth token, got nil")
}
if err.Error() != "oauth token required" {
t.Fatalf("expected error 'oauth token required', got %s", err.Error())
}
if derived != nil {
t.Error("expected nil derived manager when oauth token required")
}
})
t.Run("with RequireOAuth=true and invalid authorization header returns oauth token required error", func(t *testing.T) {
testStaticConfig := &config.StaticConfig{
KubeConfig: kubeconfigPath,
RequireOAuth: true,
DisabledTools: []string{"configuration_view"},
DeniedResources: []config.GroupVersionKind{
{Group: "apps", Version: "v1", Kind: "Deployment"},
},
}
testManager, err := NewManager(testStaticConfig)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
defer testManager.Close()
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "invalid-token")
derived, err := testManager.Derived(ctx)
if err == nil {
t.Fatal("expected error for invalid oauth token, got nil")
}
if err.Error() != "oauth token required" {
t.Fatalf("expected error 'oauth token required', got %s", err.Error())
}
if derived != nil {
t.Error("expected nil derived manager when oauth token required")
}
})
t.Run("with RequireOAuth=true and valid bearer token creates derived manager", func(t *testing.T) {
testStaticConfig := &config.StaticConfig{
KubeConfig: kubeconfigPath,
RequireOAuth: true,
DisabledTools: []string{"configuration_view"},
DeniedResources: []config.GroupVersionKind{
{Group: "apps", Version: "v1", Kind: "Deployment"},
},
}
testManager, err := NewManager(testStaticConfig)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
defer testManager.Close()
testBearerToken := "test-bearer-token-123"
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "Bearer "+testBearerToken)
derived, err := testManager.Derived(ctx)
if err != nil {
t.Fatalf("failed to create manager: %v", err)
}
if derived.manager == testManager {
t.Error("expected new derived manager, got original manager")
}
if derived.manager.staticConfig != testStaticConfig {
t.Error("staticConfig not properly wired to derived manager")
}
derivedCfg := derived.manager.cfg
if derivedCfg == nil {
t.Fatal("derived config is nil")
}
if derivedCfg.BearerToken != testBearerToken {
t.Errorf("expected BearerToken %s, got %s", testBearerToken, derivedCfg.BearerToken)
}
})
}

39
pkg/kubernetes/token.go Normal file
View File

@@ -0,0 +1,39 @@
package kubernetes
import (
"context"
"fmt"
authenticationv1api "k8s.io/api/authentication/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func (m *Manager) VerifyToken(ctx context.Context, token, audience string) (*authenticationv1api.UserInfo, []string, error) {
tokenReviewClient, err := m.accessControlClientSet.TokenReview()
if err != nil {
return nil, nil, err
}
tokenReview := &authenticationv1api.TokenReview{
TypeMeta: metav1.TypeMeta{
APIVersion: "authentication.k8s.io/v1",
Kind: "TokenReview",
},
Spec: authenticationv1api.TokenReviewSpec{
Token: token,
Audiences: []string{audience},
},
}
result, err := tokenReviewClient.Create(ctx, tokenReview, metav1.CreateOptions{})
if err != nil {
return nil, nil, fmt.Errorf("failed to create token review: %v", err)
}
if !result.Status.Authenticated {
if result.Status.Error != "" {
return nil, nil, fmt.Errorf("token authentication failed: %s", result.Status.Error)
}
return nil, nil, fmt.Errorf("token authentication failed")
}
return &result.Status.User, result.Status.Audiences, nil
}

View File

@@ -30,7 +30,11 @@ func (s *Server) eventsList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
if namespace == nil {
namespace = ""
}
eventMap, err := s.k.Derived(ctx).EventsList(ctx, namespace.(string))
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
eventMap, err := derived.EventsList(ctx, namespace.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list events in all namespaces: %v", err)), nil
}

View File

@@ -65,7 +65,11 @@ func (s *Server) helmInstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Derived(ctx).NewHelm().Install(ctx, chart, values, name, namespace)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.NewHelm().Install(ctx, chart, values, name, namespace)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to install helm chart '%s': %w", chart, err)), nil
}
@@ -81,7 +85,11 @@ func (s *Server) helmList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Derived(ctx).NewHelm().List(namespace, allNamespaces)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.NewHelm().List(namespace, allNamespaces)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list helm releases in namespace '%s': %w", namespace, err)), nil
}
@@ -98,7 +106,11 @@ func (s *Server) helmUninstall(ctx context.Context, ctr mcp.CallToolRequest) (*m
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Derived(ctx).NewHelm().Uninstall(name, namespace)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.NewHelm().Uninstall(name, namespace)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to uninstall helm chart '%s': %w", name, err)), nil
}

View File

@@ -2,11 +2,13 @@ package mcp
import (
"context"
"fmt"
"net/http"
"slices"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
authenticationapiv1 "k8s.io/api/authentication/v1"
"k8s.io/utils/ptr"
"github.com/manusa/kubernetes-mcp-server/pkg/config"
@@ -103,6 +105,23 @@ func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer
return server.NewStreamableHTTPServer(s.server, options...)
}
// VerifyToken verifies the given token with the audience by
// sending an TokenReview request to API Server.
func (s *Server) VerifyToken(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
if s.k == nil {
return nil, nil, fmt.Errorf("kubernetes manager is not initialized")
}
return s.k.VerifyToken(ctx, token, audience)
}
// GetKubernetesAPIServerHost returns the Kubernetes API server host from the configuration.
func (s *Server) GetKubernetesAPIServerHost() string {
if s.k == nil {
return ""
}
return s.k.GetAPIServerHost()
}
func (s *Server) Close() {
if s.k != nil {
s.k.Close()

View File

@@ -38,7 +38,11 @@ func (s *Server) initNamespaces() []server.ServerTool {
}
func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
ret, err := s.k.Derived(ctx).NamespacesList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.NamespacesList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list namespaces: %v", err)), nil
}
@@ -46,7 +50,11 @@ func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mc
}
func (s *Server) projectsList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
ret, err := s.k.Derived(ctx).ProjectsList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.ProjectsList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list projects: %v", err)), nil
}

View File

@@ -129,7 +129,11 @@ func (s *Server) podsListInAllNamespaces(ctx context.Context, ctr mcp.CallToolRe
if labelSelector != nil {
resourceListOptions.ListOptions.LabelSelector = labelSelector.(string)
}
ret, err := s.k.Derived(ctx).PodsListInAllNamespaces(ctx, resourceListOptions)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsListInAllNamespaces(ctx, resourceListOptions)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list pods in all namespaces: %v", err)), nil
}
@@ -148,7 +152,11 @@ func (s *Server) podsListInNamespace(ctx context.Context, ctr mcp.CallToolReques
if labelSelector != nil {
resourceListOptions.ListOptions.LabelSelector = labelSelector.(string)
}
ret, err := s.k.Derived(ctx).PodsListInNamespace(ctx, ns.(string), resourceListOptions)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsListInNamespace(ctx, ns.(string), resourceListOptions)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list pods in namespace %s: %v", ns, err)), nil
}
@@ -164,7 +172,11 @@ func (s *Server) podsGet(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if name == nil {
return NewTextResult("", errors.New("failed to get pod, missing argument name")), nil
}
ret, err := s.k.Derived(ctx).PodsGet(ctx, ns.(string), name.(string))
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsGet(ctx, ns.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pod %s in namespace %s: %v", name, ns, err)), nil
}
@@ -180,7 +192,11 @@ func (s *Server) podsDelete(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
if name == nil {
return NewTextResult("", errors.New("failed to delete pod, missing argument name")), nil
}
ret, err := s.k.Derived(ctx).PodsDelete(ctx, ns.(string), name.(string))
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsDelete(ctx, ns.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to delete pod %s in namespace %s: %v", name, ns, err)), nil
}
@@ -201,7 +217,11 @@ func (s *Server) podsTop(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if v, ok := ctr.GetArguments()["label_selector"].(string); ok {
podsTopOptions.LabelSelector = v
}
ret, err := s.k.Derived(ctx).PodsTop(ctx, podsTopOptions)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsTop(ctx, podsTopOptions)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pods top: %v", err)), nil
}
@@ -238,7 +258,11 @@ func (s *Server) podsExec(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
} else {
return NewTextResult("", errors.New("failed to exec in pod, invalid command argument")), nil
}
ret, err := s.k.Derived(ctx).PodsExec(ctx, ns.(string), name.(string), container.(string), command)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsExec(ctx, ns.(string), name.(string), container.(string), command)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to exec in pod %s in namespace %s: %v", name, ns, err)), nil
} else if ret == "" {
@@ -260,7 +284,11 @@ func (s *Server) podsLog(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if container == nil {
container = ""
}
ret, err := s.k.Derived(ctx).PodsLog(ctx, ns.(string), name.(string), container.(string))
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.PodsLog(ctx, ns.(string), name.(string), container.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
} else if ret == "" {
@@ -286,7 +314,11 @@ func (s *Server) podsRun(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if port == nil {
port = float64(0)
}
resources, err := s.k.Derived(ctx).PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
resources, err := derived.PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to run pod %s in namespace %s: %v", name, ns, err)), nil
}

View File

@@ -128,7 +128,11 @@ func (s *Server) resourcesList(ctx context.Context, ctr mcp.CallToolRequest) (*m
return NewTextResult("", fmt.Errorf("namespace is not a string")), nil
}
ret, err := s.k.Derived(ctx).ResourcesList(ctx, gvk, ns, resourceListOptions)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.ResourcesList(ctx, gvk, ns, resourceListOptions)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list resources: %v", err)), nil
}
@@ -159,7 +163,11 @@ func (s *Server) resourcesGet(ctx context.Context, ctr mcp.CallToolRequest) (*mc
return NewTextResult("", fmt.Errorf("name is not a string")), nil
}
ret, err := s.k.Derived(ctx).ResourcesGet(ctx, gvk, ns, n)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
ret, err := derived.ResourcesGet(ctx, gvk, ns, n)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get resource: %v", err)), nil
}
@@ -177,7 +185,11 @@ func (s *Server) resourcesCreateOrUpdate(ctx context.Context, ctr mcp.CallToolRe
return NewTextResult("", fmt.Errorf("resource is not a string")), nil
}
resources, err := s.k.Derived(ctx).ResourcesCreateOrUpdate(ctx, r)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
resources, err := derived.ResourcesCreateOrUpdate(ctx, r)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to create or update resources: %v", err)), nil
}
@@ -212,7 +224,11 @@ func (s *Server) resourcesDelete(ctx context.Context, ctr mcp.CallToolRequest) (
return NewTextResult("", fmt.Errorf("name is not a string")), nil
}
err = s.k.Derived(ctx).ResourcesDelete(ctx, gvk, ns, n)
derived, err := s.k.Derived(ctx)
if err != nil {
return nil, err
}
err = derived.ResourcesDelete(ctx, gvk, ns, n)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to delete resource: %v", err)), nil
}