feat(auth): authorize user from custom SSE header (96)

feat(auth): Authorize user from custom SSE header

PoC to show how we can propagate an Authorization Bearer token
from the MCP client up to the Kubernetes API by passing a custom
header (Kubernetes-Authorization-Bearer-Token).

A new Derived client is necessary for each request due to the incompleteness
of some of the client-go clients.
This might add some overhead for each prompt.
Ideally, the issue with the discoveryclient and others should be fixed to
allow reading the authorization header from the request context.

To use the feature, the MCP Server still needs to be started with a basic
configuration (either provided InCluster by a service account or locally by
 a .kube/config file) so that it's able to infer the server settings.
---
test(auth): added tests to verify header propagation
---
refactor(auth): minor improvements for derived client
This commit is contained in:
Marc Nuri
2025-05-29 17:07:28 +02:00
committed by GitHub
parent 9830e2249d
commit f80d8df3c4
10 changed files with 182 additions and 22 deletions

View File

@@ -0,0 +1,15 @@
package kubernetes
import "net/http"
type impersonateRoundTripper struct {
delegate http.RoundTripper
}
func (irt *impersonateRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: Solution won't work with discoveryclient which uses context.TODO() instead of the passed-in context
if v, ok := req.Context().Value(AuthorizationHeader).(string); ok {
req.Header.Set("Authorization", v)
}
return irt.delegate.RoundTrip(req)
}

View File

@@ -1,6 +1,7 @@
package kubernetes
import (
"context"
"github.com/fsnotify/fsnotify"
"github.com/manusa/kubernetes-mcp-server/pkg/helm"
v1 "k8s.io/api/core/v1"
@@ -15,9 +16,15 @@ import (
"k8s.io/client-go/rest"
"k8s.io/client-go/restmapper"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"sigs.k8s.io/yaml"
)
const (
AuthorizationHeader = "Kubernetes-Authorization"
AuthorizationBearerTokenHeader = "kubernetes-authorization-bearer-token"
)
type CloseWatchKubeConfig func() error
type Kubernetes struct {
@@ -42,6 +49,10 @@ func NewKubernetes(kubeconfig string) (*Kubernetes, error) {
if err := resolveKubernetesConfigurations(k8s); err != nil {
return nil, err
}
// TODO: Won't work because not all client-go clients use the shared context (e.g. discovery client uses context.TODO())
//k8s.cfg.Wrap(func(original http.RoundTripper) http.RoundTripper {
// return &impersonateRoundTripper{original}
//})
var err error
k8s.clientSet, err = kubernetes.NewForConfig(k8s.cfg)
if err != nil {
@@ -52,7 +63,7 @@ func NewKubernetes(kubeconfig string) (*Kubernetes, error) {
return nil, err
}
k8s.discoveryClient = memory.NewMemCacheClient(discoveryClient)
k8s.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(memory.NewMemCacheClient(k8s.discoveryClient))
k8s.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(k8s.discoveryClient)
k8s.dynamicClient, err = dynamic.NewForConfig(k8s.cfg)
if err != nil {
return nil, err
@@ -116,6 +127,50 @@ func (k *Kubernetes) ToRESTMapper() (meta.RESTMapper, error) {
return k.deferredDiscoveryRESTMapper, nil
}
func (k *Kubernetes) Derived(ctx context.Context) *Kubernetes {
bearerToken, ok := ctx.Value(AuthorizationBearerTokenHeader).(string)
if !ok {
return k
}
derivedCfg := rest.CopyConfig(k.cfg)
derivedCfg.BearerToken = bearerToken
derivedCfg.BearerTokenFile = ""
derivedCfg.Username = ""
derivedCfg.Password = ""
derivedCfg.AuthProvider = nil
derivedCfg.AuthConfigPersister = nil
derivedCfg.ExecProvider = nil
derivedCfg.Impersonate = rest.ImpersonationConfig{}
clientCmdApiConfig, err := k.clientCmdConfig.RawConfig()
if err != nil {
return k
}
clientCmdApiConfig.AuthInfos = make(map[string]*clientcmdapi.AuthInfo)
derived := &Kubernetes{
Kubeconfig: k.Kubeconfig,
clientCmdConfig: clientcmd.NewDefaultClientConfig(clientCmdApiConfig, nil),
cfg: derivedCfg,
scheme: k.scheme,
parameterCodec: k.parameterCodec,
}
derived.clientSet, err = kubernetes.NewForConfig(derived.cfg)
if err != nil {
return k
}
discoveryClient, err := discovery.NewDiscoveryClientForConfig(derived.cfg)
if err != nil {
return k
}
derived.discoveryClient = memory.NewMemCacheClient(discoveryClient)
derived.deferredDiscoveryRESTMapper = restmapper.NewDeferredDiscoveryRESTMapper(derived.discoveryClient)
derived.dynamicClient, err = dynamic.NewForConfig(derived.cfg)
if err != nil {
return k
}
derived.Helm = helm.NewHelm(derived)
return derived
}
func marshal(v any) (string, error) {
switch t := v.(type) {
case []unstructured.Unstructured:

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
@@ -97,6 +98,7 @@ type mcpContext struct {
profile Profile
readOnly bool
disableDestructive bool
clientOptions []transport.ClientOption
before func(*mcpContext)
after func(*mcpContext)
ctx context.Context
@@ -124,8 +126,8 @@ func (c *mcpContext) beforeEach(t *testing.T) {
t.Fatal(err)
return
}
c.mcpHttpServer = server.NewTestServer(c.mcpServer.server)
if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL + "/sse"); err != nil {
c.mcpHttpServer = server.NewTestServer(c.mcpServer.server, server.WithSSEContextFunc(contextFunc))
if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL+"/sse", c.clientOptions...); err != nil {
t.Fatal(err)
return
}

View File

@@ -27,7 +27,7 @@ func (s *Server) eventsList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
if namespace == nil {
namespace = ""
}
ret, err := s.k.EventsList(ctx, namespace.(string))
ret, err := s.k.Derived(ctx).EventsList(ctx, namespace.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list events in all namespaces: %v", err)), nil
}

View File

@@ -64,14 +64,14 @@ func (s *Server) helmInstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Helm.Install(ctx, chart, values, name, namespace)
ret, err := s.k.Derived(ctx).Helm.Install(ctx, chart, values, name, namespace)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to install helm chart '%s': %w", chart, err)), nil
}
return NewTextResult(ret, err), nil
}
func (s *Server) helmList(_ context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (s *Server) helmList(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
allNamespaces := false
if v, ok := ctr.GetArguments()["all_namespaces"].(bool); ok {
allNamespaces = v
@@ -80,14 +80,14 @@ func (s *Server) helmList(_ context.Context, ctr mcp.CallToolRequest) (*mcp.Call
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Helm.List(namespace, allNamespaces)
ret, err := s.k.Derived(ctx).Helm.List(namespace, allNamespaces)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list helm releases in namespace '%s': %w", namespace, err)), nil
}
return NewTextResult(ret, err), nil
}
func (s *Server) helmUninstall(_ context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (s *Server) helmUninstall(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var name string
ok := false
if name, ok = ctr.GetArguments()["name"].(string); !ok {
@@ -97,7 +97,7 @@ func (s *Server) helmUninstall(_ context.Context, ctr mcp.CallToolRequest) (*mcp
if v, ok := ctr.GetArguments()["namespace"].(string); ok {
namespace = v
}
ret, err := s.k.Helm.Uninstall(name, namespace)
ret, err := s.k.Derived(ctx).Helm.Uninstall(name, namespace)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to uninstall helm chart '%s': %w", name, err)), nil
}

View File

@@ -1,10 +1,12 @@
package mcp
import (
"context"
"github.com/manusa/kubernetes-mcp-server/pkg/kubernetes"
"github.com/manusa/kubernetes-mcp-server/pkg/version"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"net/http"
)
type Configuration struct {
@@ -71,6 +73,7 @@ func (s *Server) ServeStdio() error {
func (s *Server) ServeSse(baseUrl string) *server.SSEServer {
options := make([]server.SSEOption, 0)
options = append(options, server.WithSSEContextFunc(contextFunc))
if baseUrl != "" {
options = append(options, server.WithBaseURL(baseUrl))
}
@@ -104,3 +107,8 @@ func NewTextResult(content string, err error) *mcp.CallToolResult {
},
}
}
func contextFunc(ctx context.Context, r *http.Request) context.Context {
//return context.WithValue(ctx, kubernetes.AuthorizationHeader, r.Header.Get(kubernetes.AuthorizationHeader))
return context.WithValue(ctx, kubernetes.AuthorizationBearerTokenHeader, r.Header.Get(kubernetes.AuthorizationBearerTokenHeader))
}

View File

@@ -2,7 +2,9 @@ package mcp
import (
"context"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"net/http"
"os"
"path/filepath"
"runtime"
@@ -88,3 +90,81 @@ func TestDisableDestructive(t *testing.T) {
})
})
}
func TestSseHeaders(t *testing.T) {
mockServer := NewMockServer()
defer mockServer.Close()
before := func(c *mcpContext) {
c.withKubeConfig(mockServer.config)
c.clientOptions = append(c.clientOptions, client.WithHeaders(map[string]string{"kubernetes-authorization-bearer-token": "a-token-from-mcp-client"}))
}
pathHeaders := make(map[string]http.Header, 0)
mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
pathHeaders[req.URL.Path] = req.Header.Clone()
// Request Performed by DiscoveryClient to Kube API (Get API Groups legacy -core-)
if req.URL.Path == "/api" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"kind":"APIVersions","versions":["v1"],"serverAddressByClientCIDRs":[{"clientCIDR":"0.0.0.0/0"}]}`))
return
}
// Request Performed by DiscoveryClient to Kube API (Get API Groups)
if req.URL.Path == "/apis" {
w.Header().Set("Content-Type", "application/json")
//w.Write([]byte(`{"kind":"APIGroupList","apiVersion":"v1","groups":[{"name":"apps","versions":[{"groupVersion":"apps/v1","version":"v1"}],"preferredVersion":{"groupVersion":"apps/v1","version":"v1"}}]}`))
_, _ = w.Write([]byte(`{"kind":"APIGroupList","apiVersion":"v1","groups":[]}`))
return
}
// Request Performed by DiscoveryClient to Kube API (Get API Resources)
if req.URL.Path == "/api/v1" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"kind":"APIResourceList","apiVersion":"v1","resources":[{"name":"pods","singularName":"","namespaced":true,"kind":"Pod","verbs":["get","list","watch","create","update","patch","delete"]}]}`))
return
}
// Request Performed by DynamicClient
if req.URL.Path == "/api/v1/namespaces/default/pods" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"kind":"PodList","apiVersion":"v1","items":[]}`))
return
}
// Request Performed by kubernetes.Interface
if req.URL.Path == "/api/v1/namespaces/default/pods/a-pod-to-delete" {
w.WriteHeader(200)
return
}
w.WriteHeader(404)
}))
testCaseWithContext(t, &mcpContext{before: before}, func(c *mcpContext) {
c.callTool("pods_list", map[string]interface{}{})
t.Run("DiscoveryClient propagates headers to Kube API", func(t *testing.T) {
if len(pathHeaders) == 0 {
t.Fatalf("No requests were made to Kube API")
}
if pathHeaders["/api"] == nil || pathHeaders["/api"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
t.Fatalf("Overridden header Authorization not found in request to /api")
}
if pathHeaders["/apis"] == nil || pathHeaders["/apis"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
t.Fatalf("Overridden header Authorization not found in request to /apis")
}
if pathHeaders["/api/v1"] == nil || pathHeaders["/api/v1"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
t.Fatalf("Overridden header Authorization not found in request to /api/v1")
}
})
t.Run("DynamicClient propagates headers to Kube API", func(t *testing.T) {
if len(pathHeaders) == 0 {
t.Fatalf("No requests were made to Kube API")
}
if pathHeaders["/api/v1/namespaces/default/pods"] == nil || pathHeaders["/api/v1/namespaces/default/pods"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods")
}
})
c.callTool("pods_delete", map[string]interface{}{"name": "a-pod-to-delete"})
t.Run("kubernetes.Interface propagates headers to Kube API", func(t *testing.T) {
if len(pathHeaders) == 0 {
t.Fatalf("No requests were made to Kube API")
}
if pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"] == nil || pathHeaders["/api/v1/namespaces/default/pods/a-pod-to-delete"].Get("Authorization") != "Bearer a-token-from-mcp-client" {
t.Fatalf("Overridden header Authorization not found in request to /api/v1/namespaces/default/pods/a-pod-to-delete")
}
})
})
}

View File

@@ -35,7 +35,7 @@ func (s *Server) initNamespaces() []server.ServerTool {
}
func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
ret, err := s.k.NamespacesList(ctx)
ret, err := s.k.Derived(ctx).NamespacesList(ctx)
if err != nil {
err = fmt.Errorf("failed to list namespaces: %v", err)
}
@@ -43,7 +43,7 @@ func (s *Server) namespacesList(ctx context.Context, _ mcp.CallToolRequest) (*mc
}
func (s *Server) projectsList(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
ret, err := s.k.ProjectsList(ctx)
ret, err := s.k.Derived(ctx).ProjectsList(ctx)
if err != nil {
err = fmt.Errorf("failed to list projects: %v", err)
}

View File

@@ -110,7 +110,7 @@ func (s *Server) podsListInAllNamespaces(ctx context.Context, ctr mcp.CallToolRe
selector = labelSelector.(string)
}
ret, err := s.k.PodsListInAllNamespaces(ctx, selector)
ret, err := s.k.Derived(ctx).PodsListInAllNamespaces(ctx, selector)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list pods in all namespaces: %v", err)), nil
}
@@ -127,7 +127,7 @@ func (s *Server) podsListInNamespace(ctx context.Context, ctr mcp.CallToolReques
if labelSelector != nil {
selector = labelSelector.(string)
}
ret, err := s.k.PodsListInNamespace(ctx, ns.(string), selector)
ret, err := s.k.Derived(ctx).PodsListInNamespace(ctx, ns.(string), selector)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list pods in namespace %s: %v", ns, err)), nil
}
@@ -143,7 +143,7 @@ func (s *Server) podsGet(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if name == nil {
return NewTextResult("", errors.New("failed to get pod, missing argument name")), nil
}
ret, err := s.k.PodsGet(ctx, ns.(string), name.(string))
ret, err := s.k.Derived(ctx).PodsGet(ctx, ns.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pod %s in namespace %s: %v", name, ns, err)), nil
}
@@ -159,7 +159,7 @@ func (s *Server) podsDelete(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.
if name == nil {
return NewTextResult("", errors.New("failed to delete pod, missing argument name")), nil
}
ret, err := s.k.PodsDelete(ctx, ns.(string), name.(string))
ret, err := s.k.Derived(ctx).PodsDelete(ctx, ns.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to delete pod %s in namespace %s: %v", name, ns, err)), nil
}
@@ -190,7 +190,7 @@ func (s *Server) podsExec(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Ca
} else {
return NewTextResult("", errors.New("failed to exec in pod, invalid command argument")), nil
}
ret, err := s.k.PodsExec(ctx, ns.(string), name.(string), container.(string), command)
ret, err := s.k.Derived(ctx).PodsExec(ctx, ns.(string), name.(string), container.(string), command)
if err != nil {
return NewTextResult("", fmt.Errorf("failed to exec in pod %s in namespace %s: %v", name, ns, err)), nil
} else if ret == "" {
@@ -212,7 +212,7 @@ func (s *Server) podsLog(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if container == nil {
container = ""
}
ret, err := s.k.PodsLog(ctx, ns.(string), name.(string), container.(string))
ret, err := s.k.Derived(ctx).PodsLog(ctx, ns.(string), name.(string), container.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
} else if ret == "" {
@@ -238,7 +238,7 @@ func (s *Server) podsRun(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.Cal
if port == nil {
port = float64(0)
}
ret, err := s.k.PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
ret, err := s.k.Derived(ctx).PodsRun(ctx, ns.(string), name.(string), image.(string), int32(port.(float64)))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get pod %s log in namespace %s: %v", name, ns, err)), nil
}

View File

@@ -111,7 +111,7 @@ func (s *Server) resourcesList(ctx context.Context, ctr mcp.CallToolRequest) (*m
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list resources, %s", err)), nil
}
ret, err := s.k.ResourcesList(ctx, gvk, namespace.(string), labelSelector.(string))
ret, err := s.k.Derived(ctx).ResourcesList(ctx, gvk, namespace.(string), labelSelector.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to list resources: %v", err)), nil
}
@@ -131,7 +131,7 @@ func (s *Server) resourcesGet(ctx context.Context, ctr mcp.CallToolRequest) (*mc
if name == nil {
return NewTextResult("", errors.New("failed to get resource, missing argument name")), nil
}
ret, err := s.k.ResourcesGet(ctx, gvk, namespace.(string), name.(string))
ret, err := s.k.Derived(ctx).ResourcesGet(ctx, gvk, namespace.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to get resource: %v", err)), nil
}
@@ -143,7 +143,7 @@ func (s *Server) resourcesCreateOrUpdate(ctx context.Context, ctr mcp.CallToolRe
if resource == nil || resource == "" {
return NewTextResult("", errors.New("failed to create or update resources, missing argument resource")), nil
}
ret, err := s.k.ResourcesCreateOrUpdate(ctx, resource.(string))
ret, err := s.k.Derived(ctx).ResourcesCreateOrUpdate(ctx, resource.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to create or update resources: %v", err)), nil
}
@@ -163,7 +163,7 @@ func (s *Server) resourcesDelete(ctx context.Context, ctr mcp.CallToolRequest) (
if name == nil {
return NewTextResult("", errors.New("failed to delete resource, missing argument name")), nil
}
err = s.k.ResourcesDelete(ctx, gvk, namespace.(string), name.(string))
err = s.k.Derived(ctx).ResourcesDelete(ctx, gvk, namespace.(string), name.(string))
if err != nil {
return NewTextResult("", fmt.Errorf("failed to delete resource: %v", err)), nil
}