feat(auth): token exchange auth workflow (#255)

Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
Marc Nuri
2025-08-08 15:30:33 +03:00
committed by GitHub
parent 58c47dc95c
commit 90d4bb03f3
5 changed files with 188 additions and 27 deletions

View File

@@ -22,18 +22,27 @@ type StaticConfig struct {
DisableDestructive bool `toml:"disable_destructive,omitempty"`
EnabledTools []string `toml:"enabled_tools,omitempty"`
DisabledTools []string `toml:"disabled_tools,omitempty"`
RequireOAuth bool `toml:"require_oauth,omitempty"`
//Authorization related fields
// Authorization-related fields
// RequireOAuth indicates whether the server requires OAuth for authentication.
RequireOAuth bool `toml:"require_oauth,omitempty"`
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
OAuthAudience string `toml:"oauth_audience,omitempty"`
// ValidateToken indicates whether the server should validate the token against the Kubernetes API Server using TokenReview.
ValidateToken bool `toml:"validate_token,omitempty"`
// AuthorizationURL is the URL of the OIDC authorization server.
// It is used for token validation and for STS token exchange.
AuthorizationURL string `toml:"authorization_url,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
AuthorizationURL string `toml:"authorization_url,omitempty"`
// StsClientId is the OAuth client ID used for backend token exchange
StsClientId string `toml:"sts_client_id,omitempty"`
// StsClientSecret is the OAuth client secret used for backend token exchange
StsClientSecret string `toml:"sts_client_secret,omitempty"`
// StsAudience is the audience for the STS token exchange.
StsAudience string `toml:"sts_audience,omitempty"`
// StsScopes is the scopes for the STS token exchange.
StsScopes []string `toml:"sts_scopes,omitempty"`
CertificateAuthority string `toml:"certificate_authority,omitempty"`
ServerURL string `toml:"server_url,omitempty"`
}
type GroupVersionKind struct {

View File

@@ -6,15 +6,16 @@ import (
"net/http"
"strings"
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"golang.org/x/oauth2"
authenticationapiv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
"k8s.io/utils/strings/slices"
"github.com/containers/kubernetes-mcp-server/pkg/config"
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
)
type KubernetesApiTokenVerifier interface {
@@ -26,7 +27,7 @@ type KubernetesApiTokenVerifier interface {
//
// The flow is skipped for unprotected resources, such as health checks and well-known endpoints.
//
// There are several auth scenarios
// There are several auth scenarios supported by this middleware:
//
// 1. requireOAuth is false:
//
@@ -42,13 +43,25 @@ type KubernetesApiTokenVerifier interface {
// - If OAuthAudience is set, the token is validated against the audience.
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
//
// see TestAuthorizationRawToken
//
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
// - The token is validated offline for basic sanity checks (audience and expiration).
// - If OAuthAudience is set, the token is validated against the audience.
// - The token is then validated against the OIDC Provider.
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
//
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
// see TestAuthorizationOidcToken
//
// 2.3. OIDC Token Exchange (oidcProvider is not nil, StsClientId and StsAudience are set):
// - The token is validated offline for basic sanity checks (audience and expiration).
// - If OAuthAudience is set, the token is validated against the audience.
// - The token is then validated against the OIDC Provider.
// - If the token is valid, an external account token exchange is performed using
// the OIDC Provider to obtain a new token with the specified audience and scopes.
// - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview.
//
// see TestAuthorizationOidcTokenExchange
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -96,6 +109,22 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
r = r.WithContext(context.WithValue(r.Context(), mcp.TokenScopesContextKey, scopes))
}
// Token exchange with OIDC provider
sts := NewFromConfig(staticConfig, oidcProvider)
if err == nil && sts.IsEnabled() {
var exchangedToken *oauth2.Token
// If the token is valid, we can exchange it for a new token with the specified audience and scopes.
exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{
AccessToken: claims.Token,
TokenType: "Bearer",
})
if err == nil {
// Replace the original token with the exchanged token
token = exchangedToken.AccessToken
claims, err = ParseJWTClaims(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) // TODO: Implement test to verify, THIS IS A CRITICAL PART
}
}
// Kubernetes API Server TokenReview validation
if err == nil && staticConfig.ValidateToken {
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)

View File

@@ -132,9 +132,18 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
test(httpCtx)
}
func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) {
type OidcTestServer struct {
*rsa.PrivateKey
*oidc.Provider
*httptest.Server
TokenEndpointHandler http.HandlerFunc
}
func NewOidcTestServer(t *testing.T) (oidcTestServer *OidcTestServer) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
var err error
oidcTestServer = &OidcTestServer{}
oidcTestServer.PrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key for oidc: %v", err)
}
@@ -142,15 +151,21 @@ func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *
Algorithms: []string{oidc.RS256, oidc.ES256},
PublicKeys: []oidctest.PublicKey{
{
PublicKey: privateKey.Public(),
PublicKey: oidcTestServer.Public(),
KeyID: "test-oidc-key-id",
Algorithm: oidc.RS256,
},
},
}
httpServer = httptest.NewServer(oidcServer)
oidcServer.SetIssuer(httpServer.URL)
oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL)
oidcTestServer.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/token" && oidcTestServer.TokenEndpointHandler != nil {
oidcTestServer.TokenEndpointHandler.ServeHTTP(w, r)
return
}
oidcServer.ServeHTTP(w, r)
}))
oidcServer.SetIssuer(oidcTestServer.URL)
oidcTestServer.Provider, err = oidc.NewProvider(t.Context(), oidcTestServer.URL)
if err != nil {
t.Fatalf("failed to create OIDC provider: %v", err)
}
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
})
// Failed OIDC validation
key, oidcProvider, httpServer := NewOidcTestServer(t)
t.Cleanup(httpServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
oidcTestServer := NewOidcTestServer(t)
t.Cleanup(oidcTestServer.Close)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
// Failed Kubernetes TokenReview
rawClaims := `{
"iss": "` + httpServer.URL + `",
"iss": "` + oidcTestServer.URL + `",
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
})
}
// TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
func TestAuthorizationRequireOAuthFalse(t *testing.T) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
}
func TestAuthorizationOidcToken(t *testing.T) {
key, oidcProvider, httpServer := NewOidcTestServer(t)
t.Cleanup(httpServer.Close)
oidcTestServer := NewOidcTestServer(t)
t.Cleanup(oidcTestServer.Close)
rawClaims := `{
"iss": "` + httpServer.URL + `",
"iss": "` + oidcTestServer.URL + `",
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
"aud": "mcp-server"
}`
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
cases := []bool{false, true}
for _, validateToken := range cases {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
}
})
})
}
}
func TestAuthorizationOidcTokenExchange(t *testing.T) {
oidcTestServer := NewOidcTestServer(t)
t.Cleanup(oidcTestServer.Close)
rawClaims := `{
"iss": "` + oidcTestServer.URL + `",
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
"aud": "%s"
}`
validOidcClientToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
fmt.Sprintf(rawClaims, "mcp-server"))
validOidcBackendToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
fmt.Sprintf(rawClaims, "backend-audience"))
oidcTestServer.TokenEndpointHandler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = fmt.Fprintf(w, `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}`, validOidcBackendToken)
}
cases := []bool{false, true}
for _, validateToken := range cases {
staticConfig := &config.StaticConfig{
RequireOAuth: true,
OAuthAudience: "mcp-server",
ValidateToken: validateToken,
StsClientId: "test-sts-client-id",
StsClientSecret: "test-sts-client-secret",
StsAudience: "backend-audience",
StsScopes: []string{"backend-scope"},
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
tokenReviewed := false
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tokenReviewSuccessful))
tokenReviewed = true
return
}
}))
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+validOidcClientToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get protected endpoint: %v", err)
}
t.Cleanup(func() { _ = resp.Body.Close() })
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK", validateToken), func(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
}
})
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly", validateToken), func(t *testing.T) {
if tokenReviewed == true && !validateToken {
t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed")
}
if tokenReviewed == false && validateToken {
t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped")
}
})
})
}
}

View File

@@ -6,6 +6,8 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/externalaccount"
"github.com/containers/kubernetes-mcp-server/pkg/config"
)
type staticSubjectTokenSupplier struct {
@@ -26,6 +28,20 @@ type SecurityTokenService struct {
ExternalAccountScopes []string
}
func NewFromConfig(config *config.StaticConfig, provider *oidc.Provider) *SecurityTokenService {
return &SecurityTokenService{
Provider: provider,
ClientId: config.StsClientId,
ClientSecret: config.StsClientSecret,
ExternalAccountAudience: config.StsAudience,
ExternalAccountScopes: config.StsScopes,
}
}
func (sts *SecurityTokenService) IsEnabled() bool {
return sts.Provider != nil && sts.ClientId != "" && sts.ExternalAccountAudience != ""
}
func (sts *SecurityTokenService) ExternalAccountTokenExchange(ctx context.Context, originalToken *oauth2.Token) (*oauth2.Token, error) {
ts, err := externalaccount.NewTokenSource(ctx, externalaccount.Config{
TokenURL: sts.Endpoint().TokenURL,

View File

@@ -12,6 +12,36 @@ import (
"golang.org/x/oauth2"
)
func TestIsEnabled(t *testing.T) {
disabledCases := []SecurityTokenService{
{},
{Provider: nil},
{Provider: &oidc.Provider{}},
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ClientSecret: "test-client-secret"},
{ClientId: "test-client-id", ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
{Provider: &oidc.Provider{}, ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
}
for _, sts := range disabledCases {
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = false", sts), func(t *testing.T) {
if sts.IsEnabled() {
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = true; want false", sts)
}
})
}
enabledCases := []SecurityTokenService{
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience"},
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret"},
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret", ExternalAccountScopes: []string{"test-scope"}},
}
for _, sts := range enabledCases {
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = true", sts), func(t *testing.T) {
if !sts.IsEnabled() {
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = false; want true", sts)
}
})
}
}
func TestExternalAccountTokenExchange(t *testing.T) {
mockServer := test.NewMockServer()
authServer := mockServer.Config().Host