mirror of
https://github.com/containers/kubernetes-mcp-server.git
synced 2025-10-23 01:22:57 +03:00
feat(auth): introduce require-oauth flag to comply with OAuth in MCP specification (170)
Introduce require-oauth flag When this flag is enabled, authorization middleware will be turned on. When this flag is enabled, Derived which is generated based on the client token will not be used. --- Wire Authorization middleware to http mux This commit adds authorization middleware. Additionally, this commit rejects the requests if the bearer token is absent in Authorization header of the request. --- Add offline token validation for expiration and audience Per Model Context Protocol specification, MCP Servers must check the audience field of the token to ensure that they are generated specifically for them. This commits parses the JWT token and asserts that audience is correct and token is not expired. --- Add online token verification via TokenReview request to API Server This commit sends online token verification by sending request to TokenReview endpoint of API Server with the token and expected audience. If API Server returns the status as authenticated, that means this token can be used to generate a new ad hoc token for MCP Server. If API Server returns the status as not authenticated, that means this token is invalid and MCP Server returns 401 to force the client to initiate OAuth flow. --- Serve oauth protected resource metadata endpoint --- Introduce server-url to be represented in protected resource metadata --- Add error return type in Derived function --- Return error if error occurs in Derived, when require-oauth --- Add test cases for authorization-url and server-url --- Wire server-url to audience, if it is set --- Remove redundant ssebaseurl parameter from http
This commit is contained in:
@@ -22,6 +22,9 @@ type StaticConfig struct {
|
||||
DisableDestructive bool `toml:"disable_destructive,omitempty"`
|
||||
EnabledTools []string `toml:"enabled_tools,omitempty"`
|
||||
DisabledTools []string `toml:"disabled_tools,omitempty"`
|
||||
RequireOAuth bool `toml:"require_oauth,omitempty"`
|
||||
AuthorizationURL string `toml:"authorization_url,omitempty"`
|
||||
ServerURL string `toml:"server_url,omitempty"`
|
||||
}
|
||||
|
||||
type GroupVersionKind struct {
|
||||
|
||||
120
pkg/http/authorization.go
Normal file
120
pkg/http/authorization.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
Audience = "kubernetes-mcp-server"
|
||||
)
|
||||
|
||||
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
|
||||
func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !requireOAuth {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
audience := Audience
|
||||
if serverURL != "" {
|
||||
audience = serverURL
|
||||
}
|
||||
|
||||
err := validateJWTToken(token, audience)
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token using Kubernetes TokenReview API
|
||||
_, _, err = mcpServer.VerifyToken(r.Context(), token, Audience)
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type JWTClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Audience []string `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
}
|
||||
|
||||
// validateJWTToken validates basic JWT claims without signature verification
|
||||
func validateJWTToken(token, audience string) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid JWT token format")
|
||||
}
|
||||
|
||||
claims, err := parseJWTClaims(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, audience) {
|
||||
return fmt.Errorf("token audience mismatch: %v", claims.Audience)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseJWTClaims(payload string) (*JWTClaims, error) {
|
||||
// Add padding if needed
|
||||
if len(payload)%4 != 0 {
|
||||
payload += strings.Repeat("=", 4-len(payload)%4)
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
272
pkg/http/authorization_test.go
Normal file
272
pkg/http/authorization_test.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseJWTClaims(t *testing.T) {
|
||||
t.Run("valid JWT payload", func(t *testing.T) {
|
||||
// Sample payload from a valid JWT
|
||||
payload := "eyJhdWQiOlsiaHR0cHM6Ly9rdWJlcm5ldGVzLmRlZmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWwiLCJrdWJlcm5ldGVzLW1jcC1zZXJ2ZXIiXSwiZXhwIjoxNzUxOTYzOTQ4LCJpYXQiOjE3NTE5NjAzNDgsImlzcyI6Imh0dHBzOi8va3ViZXJuZXRlcy5kZWZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsIiwianRpIjoiOTkyMjJkNTYtMzQwZS00ZWI2LTg1ODgtMjYxNDExZjM1ZDI2Iiwia3ViZXJuZXRlcy5pbyI6eyJuYW1lc3BhY2UiOiJkZWZhdWx0Iiwic2VydmljZWFjY291bnQiOnsibmFtZSI6ImRlZmF1bHQiLCJ1aWQiOiJlYWNiNmFkMi04MGI3LTQxNzktODQzZC05MmViMWU2YmJiYTYifX0sIm5iZiI6MTc1MTk2MDM0OCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OmRlZmF1bHQ6ZGVmYXVsdCJ9"
|
||||
|
||||
claims, err := parseJWTClaims(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if claims == nil {
|
||||
t.Fatal("expected claims, got nil")
|
||||
}
|
||||
|
||||
if claims.Issuer != "https://kubernetes.default.svc.cluster.local" {
|
||||
t.Errorf("expected issuer 'https://kubernetes.default.svc.cluster.local', got %s", claims.Issuer)
|
||||
}
|
||||
|
||||
expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"}
|
||||
if len(claims.Audience) != 2 {
|
||||
t.Errorf("expected 2 audiences, got %d", len(claims.Audience))
|
||||
}
|
||||
|
||||
for i, expected := range expectedAudiences {
|
||||
if i >= len(claims.Audience) || claims.Audience[i] != expected {
|
||||
t.Errorf("expected audience[%d] to be %s, got %s", i, expected, claims.Audience[i])
|
||||
}
|
||||
}
|
||||
|
||||
if claims.ExpiresAt != 1751963948 {
|
||||
t.Errorf("expected exp 1751963948, got %d", claims.ExpiresAt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("payload needs padding", func(t *testing.T) {
|
||||
// Create a payload that needs padding
|
||||
testClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"test-audience"},
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(testClaims)
|
||||
// Create a payload without proper padding
|
||||
encodedWithoutPadding := strings.TrimRight(base64.URLEncoding.EncodeToString(jsonBytes), "=")
|
||||
|
||||
claims, err := parseJWTClaims(encodedWithoutPadding)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if claims.Issuer != "test-issuer" {
|
||||
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid base64 payload", func(t *testing.T) {
|
||||
invalidPayload := "invalid-base64!!!"
|
||||
|
||||
_, err := parseJWTClaims(invalidPayload)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to decode JWT payload") {
|
||||
t.Errorf("expected decode error message, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON payload", func(t *testing.T) {
|
||||
// Valid base64 but invalid JSON
|
||||
invalidJSON := base64.URLEncoding.EncodeToString([]byte("{invalid-json"))
|
||||
|
||||
_, err := parseJWTClaims(invalidJSON)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to unmarshal JWT claims") {
|
||||
t.Errorf("expected unmarshal error message, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateJWTToken(t *testing.T) {
|
||||
t.Run("invalid token format - not enough parts", func(t *testing.T) {
|
||||
invalidToken := "header.payload"
|
||||
|
||||
err := validateJWTToken(invalidToken, "test")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid token format, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid JWT token format") {
|
||||
t.Errorf("expected format error message, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
// Create an expired token
|
||||
expiredClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"kubernetes-mcp-server"},
|
||||
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // 1 hour ago
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(expiredClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
expiredToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
|
||||
if err == nil {
|
||||
t.Error("expected error for expired token, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "token expired") {
|
||||
t.Errorf("expected expiration error message, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple audiences with correct one", func(t *testing.T) {
|
||||
// Create a token with multiple audiences including the correct one
|
||||
multiAudClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"},
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(multiAudClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
multiAudToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
|
||||
if err != nil {
|
||||
t.Errorf("expected no error for token with multiple audiences, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("audience mismatch", func(t *testing.T) {
|
||||
// Create a token with wrong audience
|
||||
wrongAudClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"wrong-audience"},
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(wrongAudClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
wrongAudToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(wrongAudToken, "audience")
|
||||
if err == nil {
|
||||
t.Error("expected error for token with wrong audience, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "audience mismatch") {
|
||||
t.Errorf("expected audience mismatch error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorizationMiddleware(t *testing.T) {
|
||||
// Create a mock handler
|
||||
handlerCalled := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("OAuth disabled - passes through", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth disabled
|
||||
middleware := AuthorizationMiddleware(false, "", nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request without authorization header
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called when OAuth is disabled")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("healthz endpoint - passes through", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request to healthz endpoint
|
||||
req := httptest.NewRequest("GET", "/healthz", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("expected handler to be called for healthz endpoint")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OAuth enabled - missing token", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request without authorization header
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("expected handler NOT to be called when token is missing")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "Bearer token required") {
|
||||
t.Errorf("expected bearer token error message, got %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OAuth enabled - invalid token format", func(t *testing.T) {
|
||||
handlerCalled = false
|
||||
|
||||
// Create middleware with OAuth enabled
|
||||
middleware := AuthorizationMiddleware(true, "", nil)
|
||||
wrappedHandler := middleware(handler)
|
||||
|
||||
// Create request with invalid bearer token
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("expected handler NOT to be called when token is invalid")
|
||||
}
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "Invalid token") {
|
||||
t.Errorf("expected invalid token error message, got %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -11,19 +12,23 @@ import (
|
||||
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/config"
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
|
||||
)
|
||||
|
||||
func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string) error {
|
||||
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig) error {
|
||||
mux := http.NewServeMux()
|
||||
wrappedMux := RequestMiddleware(mux)
|
||||
|
||||
wrappedMux := RequestMiddleware(
|
||||
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, mcpServer)(mux),
|
||||
)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: ":" + port,
|
||||
Addr: ":" + staticConfig.Port,
|
||||
Handler: wrappedMux,
|
||||
}
|
||||
|
||||
sseServer := mcpServer.ServeSse(sseBaseUrl, httpServer)
|
||||
sseServer := mcpServer.ServeSse(staticConfig.SSEBaseURL, httpServer)
|
||||
streamableHttpServer := mcpServer.ServeHTTP(httpServer)
|
||||
mux.Handle("/sse", sseServer)
|
||||
mux.Handle("/message", sseServer)
|
||||
@@ -31,6 +36,34 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string)
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var authServers []string
|
||||
if staticConfig.AuthorizationURL != "" {
|
||||
authServers = []string{staticConfig.AuthorizationURL}
|
||||
} else {
|
||||
// Fallback to Kubernetes API server host if authorization_server is not configured
|
||||
if apiServerHost := mcpServer.GetKubernetesAPIServerHost(); apiServerHost != "" {
|
||||
authServers = []string{apiServerHost}
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"authorization_servers": authServers,
|
||||
"scopes_supported": []string{},
|
||||
"bearer_methods_supported": []string{"header"},
|
||||
}
|
||||
|
||||
if staticConfig.ServerURL != "" {
|
||||
response["resource"] = staticConfig.ServerURL
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
@@ -40,7 +73,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, port, sseBaseUrl string)
|
||||
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
klog.V(0).Infof("Streaming and SSE HTTP servers starting on port %s and paths /mcp, /sse, /message", port)
|
||||
klog.V(0).Infof("Streaming and SSE HTTP servers starting on port %s and paths /mcp, /sse, /message", staticConfig.Port)
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
serverErr <- err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,11 @@ import (
|
||||
|
||||
func RequestMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/healthz" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
lrw := &loggingResponseWriter{
|
||||
|
||||
@@ -55,6 +55,9 @@ type MCPServerOptions struct {
|
||||
ListOutput string
|
||||
ReadOnly bool
|
||||
DisableDestructive bool
|
||||
RequireOAuth bool
|
||||
AuthorizationURL string
|
||||
ServerURL string
|
||||
|
||||
ConfigPath string
|
||||
StaticConfig *config.StaticConfig
|
||||
@@ -107,7 +110,12 @@ func NewMCPServer(streams genericiooptions.IOStreams) *cobra.Command {
|
||||
cmd.Flags().StringVar(&o.ListOutput, "list-output", o.ListOutput, "Output format for resource list operations (one of: "+strings.Join(output.Names, ", ")+"). Defaults to table.")
|
||||
cmd.Flags().BoolVar(&o.ReadOnly, "read-only", o.ReadOnly, "If true, only tools annotated with readOnlyHint=true are exposed")
|
||||
cmd.Flags().BoolVar(&o.DisableDestructive, "disable-destructive", o.DisableDestructive, "If true, tools annotated with destructiveHint=true are disabled")
|
||||
|
||||
cmd.Flags().BoolVar(&o.RequireOAuth, "require-oauth", o.RequireOAuth, "If true, requires OAuth authorization as defined in the Model Context Protocol (MCP) specification. This flag is ignored if transport type is stdio")
|
||||
cmd.Flags().MarkHidden("require-oauth")
|
||||
cmd.Flags().StringVar(&o.AuthorizationURL, "authorization-url", o.AuthorizationURL, "OAuth authorization server URL for protected resource endpoint. If not provided, the Kubernetes API server host will be used. Only valid if require-oauth is enabled.")
|
||||
cmd.Flags().MarkHidden("authorization-url")
|
||||
cmd.Flags().StringVar(&o.ServerURL, "server-url", o.ServerURL, "Server URL of this application. Optional. If set, this url will be served in protected resource metadata endpoint and tokens will be validated with this audience. If not set, expected audience is kubernetes-mcp-server. Only valid if require-oauth is enabled.")
|
||||
cmd.Flags().MarkHidden("server-url")
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -124,6 +132,11 @@ func (m *MCPServerOptions) Complete(cmd *cobra.Command) error {
|
||||
|
||||
m.initializeLogging()
|
||||
|
||||
if m.StaticConfig.RequireOAuth && m.StaticConfig.Port == "" {
|
||||
// RequireOAuth is not relevant flow for STDIO transport
|
||||
m.StaticConfig.RequireOAuth = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -153,6 +166,15 @@ func (m *MCPServerOptions) loadFlags(cmd *cobra.Command) {
|
||||
if cmd.Flag("disable-destructive").Changed {
|
||||
m.StaticConfig.DisableDestructive = m.DisableDestructive
|
||||
}
|
||||
if cmd.Flag("require-oauth").Changed {
|
||||
m.StaticConfig.RequireOAuth = m.RequireOAuth
|
||||
}
|
||||
if cmd.Flag("authorization-url").Changed {
|
||||
m.StaticConfig.AuthorizationURL = m.AuthorizationURL
|
||||
}
|
||||
if cmd.Flag("server-url").Changed {
|
||||
m.StaticConfig.ServerURL = m.ServerURL
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MCPServerOptions) initializeLogging() {
|
||||
@@ -171,6 +193,25 @@ func (m *MCPServerOptions) Validate() error {
|
||||
if m.Port != "" && (m.SSEPort > 0 || m.HttpPort > 0) {
|
||||
return fmt.Errorf("--port is mutually exclusive with deprecated --http-port and --sse-port flags")
|
||||
}
|
||||
if !m.StaticConfig.RequireOAuth && (m.StaticConfig.AuthorizationURL != "" || m.StaticConfig.ServerURL != "") {
|
||||
return fmt.Errorf("authorization-url and server-url are only valid if require-oauth is enabled")
|
||||
}
|
||||
if m.StaticConfig.AuthorizationURL != "" &&
|
||||
!strings.HasPrefix(m.StaticConfig.AuthorizationURL, "https://") {
|
||||
if strings.HasPrefix(m.StaticConfig.AuthorizationURL, "http://") {
|
||||
klog.Warningf("authorization-url is using http://, this is not recommended production use")
|
||||
} else {
|
||||
return fmt.Errorf("authorization-url must start with https://")
|
||||
}
|
||||
}
|
||||
if m.StaticConfig.ServerURL != "" &&
|
||||
!strings.HasPrefix(m.StaticConfig.ServerURL, "https://") {
|
||||
if strings.HasPrefix(m.StaticConfig.ServerURL, "http://") {
|
||||
klog.Warningf("server-url is using http://, this is not recommended production use")
|
||||
} else {
|
||||
return fmt.Errorf("server-url must start with https://")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -206,7 +247,7 @@ func (m *MCPServerOptions) Run() error {
|
||||
|
||||
if m.StaticConfig.Port != "" {
|
||||
ctx := context.Background()
|
||||
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig.Port, m.SSEBaseUrl)
|
||||
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig)
|
||||
}
|
||||
|
||||
if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
|
||||
@@ -230,3 +230,53 @@ func TestDisableDestructive(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorizationURL(t *testing.T) {
|
||||
t.Run("invalid authorization-url without protocol", func(t *testing.T) {
|
||||
ioStreams, _ := testStream()
|
||||
rootCmd := NewMCPServer(ioStreams)
|
||||
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--authorization-url", "example.com/auth", "--server-url", "https://example.com:8080"})
|
||||
err := rootCmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid authorization-url without protocol, got nil")
|
||||
}
|
||||
expected := "authorization-url must start with https://"
|
||||
if !strings.Contains(err.Error(), expected) {
|
||||
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
|
||||
}
|
||||
})
|
||||
t.Run("valid authorization-url with https", func(t *testing.T) {
|
||||
ioStreams, _ := testStream()
|
||||
rootCmd := NewMCPServer(ioStreams)
|
||||
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--authorization-url", "https://example.com/auth", "--server-url", "https://example.com:8080"})
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for valid https authorization-url, got %s", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServerURL(t *testing.T) {
|
||||
t.Run("invalid server-url without protocol", func(t *testing.T) {
|
||||
ioStreams, _ := testStream()
|
||||
rootCmd := NewMCPServer(ioStreams)
|
||||
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--server-url", "example.com:8080", "--authorization-url", "https://example.com/auth"})
|
||||
err := rootCmd.Execute()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid server-url without protocol, got nil")
|
||||
}
|
||||
expected := "server-url must start with https://"
|
||||
if !strings.Contains(err.Error(), expected) {
|
||||
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
|
||||
}
|
||||
})
|
||||
t.Run("valid server-url with https", func(t *testing.T) {
|
||||
ioStreams, _ := testStream()
|
||||
rootCmd := NewMCPServer(ioStreams)
|
||||
rootCmd.SetArgs([]string{"--version", "--require-oauth", "--port=8080", "--server-url", "https://example.com:8080", "--authorization-url", "https://example.com/auth"})
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error for valid https server-url, got %s", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
authenticationv1api "k8s.io/api/authentication/v1"
|
||||
authorizationv1api "k8s.io/api/authorization/v1"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/client-go/discovery"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
authenticationv1 "k8s.io/client-go/kubernetes/typed/authentication/v1"
|
||||
authorizationv1 "k8s.io/client-go/kubernetes/typed/authorization/v1"
|
||||
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
"k8s.io/client-go/rest"
|
||||
@@ -111,6 +113,15 @@ func (a *AccessControlClientset) SelfSubjectAccessReviews() (authorizationv1.Sel
|
||||
return a.delegate.AuthorizationV1().SelfSubjectAccessReviews(), nil
|
||||
}
|
||||
|
||||
// TokenReview returns TokenReviewInterface
|
||||
func (a *AccessControlClientset) TokenReview() (authenticationv1.TokenReviewInterface, error) {
|
||||
gvk := &schema.GroupVersionKind{Group: authenticationv1api.GroupName, Version: authorizationv1api.SchemeGroupVersion.Version, Kind: "TokenReview"}
|
||||
if !isAllowed(a.staticConfig, gvk) {
|
||||
return nil, isNotAllowedError(gvk)
|
||||
}
|
||||
return a.delegate.AuthenticationV1().TokenReviews(), nil
|
||||
}
|
||||
|
||||
func NewAccessControlClientset(cfg *rest.Config, staticConfig *config.StaticConfig) (*AccessControlClientset, error) {
|
||||
clientSet, err := kubernetes.NewForConfig(cfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package kubernetes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
@@ -125,6 +126,13 @@ func (m *Manager) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) GetAPIServerHost() string {
|
||||
if m.cfg == nil {
|
||||
return ""
|
||||
}
|
||||
return m.cfg.Host
|
||||
}
|
||||
|
||||
func (m *Manager) ToDiscoveryClient() (discovery.CachedDiscoveryInterface, error) {
|
||||
return m.discoveryClient, nil
|
||||
}
|
||||
@@ -133,10 +141,13 @@ func (m *Manager) ToRESTMapper() (meta.RESTMapper, error) {
|
||||
return m.accessControlRESTMapper, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Derived(ctx context.Context) *Kubernetes {
|
||||
func (m *Manager) Derived(ctx context.Context) (*Kubernetes, error) {
|
||||
authorization, ok := ctx.Value(OAuthAuthorizationHeader).(string)
|
||||
if !ok || !strings.HasPrefix(authorization, "Bearer ") {
|
||||
return &Kubernetes{manager: m}
|
||||
if m.staticConfig.RequireOAuth {
|
||||
return nil, errors.New("oauth token required")
|
||||
}
|
||||
return &Kubernetes{manager: m}, nil
|
||||
}
|
||||
klog.V(5).Infof("%s header found (Bearer), using provided bearer token", OAuthAuthorizationHeader)
|
||||
derivedCfg := &rest.Config{
|
||||
@@ -159,7 +170,11 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
|
||||
}
|
||||
clientCmdApiConfig, err := m.clientCmdConfig.RawConfig()
|
||||
if err != nil {
|
||||
return &Kubernetes{manager: m}
|
||||
if m.staticConfig.RequireOAuth {
|
||||
klog.Errorf("failed to get kubeconfig: %v", err)
|
||||
return nil, errors.New("failed to get kubeconfig")
|
||||
}
|
||||
return &Kubernetes{manager: m}, nil
|
||||
}
|
||||
clientCmdApiConfig.AuthInfos = make(map[string]*clientcmdapi.AuthInfo)
|
||||
derived := &Kubernetes{manager: &Manager{
|
||||
@@ -169,7 +184,11 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
|
||||
}}
|
||||
derived.manager.accessControlClientSet, err = NewAccessControlClientset(derived.manager.cfg, derived.manager.staticConfig)
|
||||
if err != nil {
|
||||
return &Kubernetes{manager: m}
|
||||
if m.staticConfig.RequireOAuth {
|
||||
klog.Errorf("failed to get kubeconfig: %v", err)
|
||||
return nil, errors.New("failed to get kubeconfig")
|
||||
}
|
||||
return &Kubernetes{manager: m}, nil
|
||||
}
|
||||
derived.manager.discoveryClient = memory.NewMemCacheClient(derived.manager.accessControlClientSet.DiscoveryClient())
|
||||
derived.manager.accessControlRESTMapper = NewAccessControlRESTMapper(
|
||||
@@ -178,9 +197,13 @@ func (m *Manager) Derived(ctx context.Context) *Kubernetes {
|
||||
)
|
||||
derived.manager.dynamicClient, err = dynamic.NewForConfig(derived.manager.cfg)
|
||||
if err != nil {
|
||||
return &Kubernetes{manager: m}
|
||||
if m.staticConfig.RequireOAuth {
|
||||
klog.Errorf("failed to initialize dynamic client: %v", err)
|
||||
return nil, errors.New("failed to initialize dynamic client")
|
||||
}
|
||||
return &Kubernetes{manager: m}, nil
|
||||
}
|
||||
return derived
|
||||
return derived, nil
|
||||
}
|
||||
|
||||
func (k *Kubernetes) NewHelm() *helm.Helm {
|
||||
|
||||
@@ -51,7 +51,10 @@ users:
|
||||
}
|
||||
defer testManager.Close()
|
||||
ctx := context.Background()
|
||||
derived := testManager.Derived(ctx)
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
if derived.manager != testManager {
|
||||
t.Errorf("expected original manager, got different manager")
|
||||
@@ -73,7 +76,10 @@ users:
|
||||
}
|
||||
defer testManager.Close()
|
||||
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "invalid-token")
|
||||
derived := testManager.Derived(ctx)
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
if derived.manager != testManager {
|
||||
t.Errorf("expected original manager, got different manager")
|
||||
@@ -96,7 +102,10 @@ users:
|
||||
defer testManager.Close()
|
||||
testBearerToken := "test-bearer-token-123"
|
||||
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "Bearer "+testBearerToken)
|
||||
derived := testManager.Derived(ctx)
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
if derived.manager == testManager {
|
||||
t.Errorf("expected new derived manager, got original manager")
|
||||
@@ -208,4 +217,100 @@ users:
|
||||
t.Error("expected dynamicClient to be initialized")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with RequireOAuth=true and no authorization header returns oauth token required error", func(t *testing.T) {
|
||||
testStaticConfig := &config.StaticConfig{
|
||||
KubeConfig: kubeconfigPath,
|
||||
RequireOAuth: true,
|
||||
DisabledTools: []string{"configuration_view"},
|
||||
DeniedResources: []config.GroupVersionKind{
|
||||
{Group: "apps", Version: "v1", Kind: "Deployment"},
|
||||
},
|
||||
}
|
||||
|
||||
testManager, err := NewManager(testStaticConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
defer testManager.Close()
|
||||
ctx := context.Background()
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing oauth token, got nil")
|
||||
}
|
||||
if err.Error() != "oauth token required" {
|
||||
t.Fatalf("expected error 'oauth token required', got %s", err.Error())
|
||||
}
|
||||
if derived != nil {
|
||||
t.Error("expected nil derived manager when oauth token required")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with RequireOAuth=true and invalid authorization header returns oauth token required error", func(t *testing.T) {
|
||||
testStaticConfig := &config.StaticConfig{
|
||||
KubeConfig: kubeconfigPath,
|
||||
RequireOAuth: true,
|
||||
DisabledTools: []string{"configuration_view"},
|
||||
DeniedResources: []config.GroupVersionKind{
|
||||
{Group: "apps", Version: "v1", Kind: "Deployment"},
|
||||
},
|
||||
}
|
||||
|
||||
testManager, err := NewManager(testStaticConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
defer testManager.Close()
|
||||
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "invalid-token")
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid oauth token, got nil")
|
||||
}
|
||||
if err.Error() != "oauth token required" {
|
||||
t.Fatalf("expected error 'oauth token required', got %s", err.Error())
|
||||
}
|
||||
if derived != nil {
|
||||
t.Error("expected nil derived manager when oauth token required")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with RequireOAuth=true and valid bearer token creates derived manager", func(t *testing.T) {
|
||||
testStaticConfig := &config.StaticConfig{
|
||||
KubeConfig: kubeconfigPath,
|
||||
RequireOAuth: true,
|
||||
DisabledTools: []string{"configuration_view"},
|
||||
DeniedResources: []config.GroupVersionKind{
|
||||
{Group: "apps", Version: "v1", Kind: "Deployment"},
|
||||
},
|
||||
}
|
||||
|
||||
testManager, err := NewManager(testStaticConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
defer testManager.Close()
|
||||
testBearerToken := "test-bearer-token-123"
|
||||
ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "Bearer "+testBearerToken)
|
||||
derived, err := testManager.Derived(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
if derived.manager == testManager {
|
||||
t.Error("expected new derived manager, got original manager")
|
||||
}
|
||||
|
||||
if derived.manager.staticConfig != testStaticConfig {
|
||||
t.Error("staticConfig not properly wired to derived manager")
|
||||
}
|
||||
|
||||
derivedCfg := derived.manager.cfg
|
||||
if derivedCfg == nil {
|
||||
t.Fatal("derived config is nil")
|
||||
}
|
||||
|
||||
if derivedCfg.BearerToken != testBearerToken {
|
||||
t.Errorf("expected BearerToken %s, got %s", testBearerToken, derivedCfg.BearerToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
39
pkg/kubernetes/token.go
Normal file
39
pkg/kubernetes/token.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package kubernetes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
authenticationv1api "k8s.io/api/authentication/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
)
|
||||
|
||||
func (m *Manager) VerifyToken(ctx context.Context, token, audience string) (*authenticationv1api.UserInfo, []string, error) {
|
||||
tokenReviewClient, err := m.accessControlClientSet.TokenReview()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tokenReview := &authenticationv1api.TokenReview{
|
||||
TypeMeta: metav1.TypeMeta{
|
||||
APIVersion: "authentication.k8s.io/v1",
|
||||
Kind: "TokenReview",
|
||||
},
|
||||
Spec: authenticationv1api.TokenReviewSpec{
|
||||
Token: token,
|
||||
Audiences: []string{audience},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := tokenReviewClient.Create(ctx, tokenReview, metav1.CreateOptions{})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create token review: %v", err)
|
||||
}
|
||||
|
||||
if !result.Status.Authenticated {
|
||||
if result.Status.Error != "" {
|
||||
return nil, nil, fmt.Errorf("token authentication failed: %s", result.Status.Error)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("token authentication failed")
|
||||
}
|
||||
|
||||
return &result.Status.User, result.Status.Audiences, nil
|
||||
}
|
||||
@@ -30,7 +30,11 @@ func (s *Server) eventsList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
|
||||
if namespace == nil {
|
||||
namespace = ""
|
||||
}
|
||||
eventMap, err := s.k.Derived(ctx).EventsList(ctx, namespace.(string))
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventMap, err := derived.EventsList(ctx, namespace.(string))
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list events in all namespaces: %v", err)), nil
|
||||
}
|
||||
|
||||
@@ -65,7 +65,11 @@ func (s *Server) helmInstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp
|
||||
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
|
||||
namespace = v
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).NewHelm().Install(ctx, chart, values, name, namespace)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.NewHelm().Install(ctx, chart, values, name, namespace)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to install helm chart '%s': %w", chart, err)), nil
|
||||
}
|
||||
@@ -81,7 +85,11 @@ func (s *Server) helmList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
|
||||
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
|
||||
namespace = v
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).NewHelm().List(namespace, allNamespaces)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.NewHelm().List(namespace, allNamespaces)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list helm releases in namespace '%s': %w", namespace, err)), nil
|
||||
}
|
||||
@@ -98,7 +106,11 @@ func (s *Server) helmUninstall(ctx context.Context, ctr mcp.CallToolRequest) (*m
|
||||
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
|
||||
namespace = v
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).NewHelm().Uninstall(name, namespace)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.NewHelm().Uninstall(name, namespace)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to uninstall helm chart '%s': %w", name, err)), nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,13 @@ package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
authenticationapiv1 "k8s.io/api/authentication/v1"
|
||||
"k8s.io/utils/ptr"
|
||||
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/config"
|
||||
@@ -103,6 +105,23 @@ func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer
|
||||
return server.NewStreamableHTTPServer(s.server, options...)
|
||||
}
|
||||
|
||||
// VerifyToken verifies the given token with the audience by
|
||||
// sending an TokenReview request to API Server.
|
||||
func (s *Server) VerifyToken(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
|
||||
if s.k == nil {
|
||||
return nil, nil, fmt.Errorf("kubernetes manager is not initialized")
|
||||
}
|
||||
return s.k.VerifyToken(ctx, token, audience)
|
||||
}
|
||||
|
||||
// GetKubernetesAPIServerHost returns the Kubernetes API server host from the configuration.
|
||||
func (s *Server) GetKubernetesAPIServerHost() string {
|
||||
if s.k == nil {
|
||||
return ""
|
||||
}
|
||||
return s.k.GetAPIServerHost()
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
if s.k != nil {
|
||||
s.k.Close()
|
||||
|
||||
@@ -38,7 +38,11 @@ func (s *Server) initNamespaces() []server.ServerTool {
|
||||
}
|
||||
|
||||
func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
ret, err := s.k.Derived(ctx).NamespacesList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.NamespacesList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list namespaces: %v", err)), nil
|
||||
}
|
||||
@@ -46,7 +50,11 @@ func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mc
|
||||
}
|
||||
|
||||
func (s *Server) projectsList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
ret, err := s.k.Derived(ctx).ProjectsList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.ProjectsList(ctx, kubernetes.ResourceListOptions{AsTable: s.configuration.ListOutput.AsTable()})
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list projects: %v", err)), nil
|
||||
}
|
||||
|
||||
@@ -129,7 +129,11 @@ func (s *Server) podsListInAllNamespaces(ctx context.Context, ctr mcp.CallToolRe
|
||||
if labelSelector != nil {
|
||||
resourceListOptions.ListOptions.LabelSelector = labelSelector.(string)
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsListInAllNamespaces(ctx, resourceListOptions)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsListInAllNamespaces(ctx, resourceListOptions)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list pods in all namespaces: %v", err)), nil
|
||||
}
|
||||
@@ -148,7 +152,11 @@ func (s *Server) podsListInNamespace(ctx context.Context, ctr mcp.CallToolReques
|
||||
if labelSelector != nil {
|
||||
resourceListOptions.ListOptions.LabelSelector = labelSelector.(string)
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsListInNamespace(ctx, ns.(string), resourceListOptions)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsListInNamespace(ctx, ns.(string), resourceListOptions)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list pods in namespace %s: %v", ns, err)), nil
|
||||
}
|
||||
@@ -164,7 +172,11 @@ func (s *Server) podsGet(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
|
||||
if name == nil {
|
||||
return NewTextResult("", errors.New("failed to get pod, missing argument name")), nil
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsGet(ctx, ns.(string), name.(string))
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsGet(ctx, ns.(string), name.(string))
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to get pod %s in namespace %s: %v", name, ns, err)), nil
|
||||
}
|
||||
@@ -180,7 +192,11 @@ func (s *Server) podsDelete(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
|
||||
if name == nil {
|
||||
return NewTextResult("", errors.New("failed to delete pod, missing argument name")), nil
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsDelete(ctx, ns.(string), name.(string))
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsDelete(ctx, ns.(string), name.(string))
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to delete pod %s in namespace %s: %v", name, ns, err)), nil
|
||||
}
|
||||
@@ -201,7 +217,11 @@ func (s *Server) podsTop(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
|
||||
if v, ok := ctr.GetArguments()["label_selector"].(string); ok {
|
||||
podsTopOptions.LabelSelector = v
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsTop(ctx, podsTopOptions)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsTop(ctx, podsTopOptions)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to get pods top: %v", err)), nil
|
||||
}
|
||||
@@ -238,7 +258,11 @@ func (s *Server) podsExec(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
|
||||
} else {
|
||||
return NewTextResult("", errors.New("failed to exec in pod, invalid command argument")), nil
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsExec(ctx, ns.(string), name.(string), container.(string), command)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsExec(ctx, ns.(string), name.(string), container.(string), command)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to exec in pod %s in namespace %s: %v", name, ns, err)), nil
|
||||
} else if ret == "" {
|
||||
@@ -260,7 +284,11 @@ func (s *Server) podsLog(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
|
||||
if container == nil {
|
||||
container = ""
|
||||
}
|
||||
ret, err := s.k.Derived(ctx).PodsLog(ctx, ns.(string), name.(string), container.(string))
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.PodsLog(ctx, ns.(string), name.(string), container.(string))
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
|
||||
} else if ret == "" {
|
||||
@@ -286,7 +314,11 @@ func (s *Server) podsRun(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
|
||||
if port == nil {
|
||||
port = float64(0)
|
||||
}
|
||||
resources, err := s.k.Derived(ctx).PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resources, err := derived.PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to run pod %s in namespace %s: %v", name, ns, err)), nil
|
||||
}
|
||||
|
||||
@@ -128,7 +128,11 @@ func (s *Server) resourcesList(ctx context.Context, ctr mcp.CallToolRequest) (*m
|
||||
return NewTextResult("", fmt.Errorf("namespace is not a string")), nil
|
||||
}
|
||||
|
||||
ret, err := s.k.Derived(ctx).ResourcesList(ctx, gvk, ns, resourceListOptions)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.ResourcesList(ctx, gvk, ns, resourceListOptions)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to list resources: %v", err)), nil
|
||||
}
|
||||
@@ -159,7 +163,11 @@ func (s *Server) resourcesGet(ctx context.Context, ctr mcp.CallToolRequest) (*mc
|
||||
return NewTextResult("", fmt.Errorf("name is not a string")), nil
|
||||
}
|
||||
|
||||
ret, err := s.k.Derived(ctx).ResourcesGet(ctx, gvk, ns, n)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := derived.ResourcesGet(ctx, gvk, ns, n)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to get resource: %v", err)), nil
|
||||
}
|
||||
@@ -177,7 +185,11 @@ func (s *Server) resourcesCreateOrUpdate(ctx context.Context, ctr mcp.CallToolRe
|
||||
return NewTextResult("", fmt.Errorf("resource is not a string")), nil
|
||||
}
|
||||
|
||||
resources, err := s.k.Derived(ctx).ResourcesCreateOrUpdate(ctx, r)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resources, err := derived.ResourcesCreateOrUpdate(ctx, r)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to create or update resources: %v", err)), nil
|
||||
}
|
||||
@@ -212,7 +224,11 @@ func (s *Server) resourcesDelete(ctx context.Context, ctr mcp.CallToolRequest) (
|
||||
return NewTextResult("", fmt.Errorf("name is not a string")), nil
|
||||
}
|
||||
|
||||
err = s.k.Derived(ctx).ResourcesDelete(ctx, gvk, ns, n)
|
||||
derived, err := s.k.Derived(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = derived.ResourcesDelete(ctx, gvk, ns, n)
|
||||
if err != nil {
|
||||
return NewTextResult("", fmt.Errorf("failed to delete resource: %v", err)), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user