mirror of
https://github.com/openshift/openshift-mcp-server.git
synced 2025-10-17 14:27:48 +03:00
feat(auth): token exchange auth workflow (#255)
Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
@@ -22,18 +22,27 @@ type StaticConfig struct {
|
||||
DisableDestructive bool `toml:"disable_destructive,omitempty"`
|
||||
EnabledTools []string `toml:"enabled_tools,omitempty"`
|
||||
DisabledTools []string `toml:"disabled_tools,omitempty"`
|
||||
RequireOAuth bool `toml:"require_oauth,omitempty"`
|
||||
|
||||
//Authorization related fields
|
||||
// Authorization-related fields
|
||||
// RequireOAuth indicates whether the server requires OAuth for authentication.
|
||||
RequireOAuth bool `toml:"require_oauth,omitempty"`
|
||||
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
|
||||
OAuthAudience string `toml:"oauth_audience,omitempty"`
|
||||
// ValidateToken indicates whether the server should validate the token against the Kubernetes API Server using TokenReview.
|
||||
ValidateToken bool `toml:"validate_token,omitempty"`
|
||||
// AuthorizationURL is the URL of the OIDC authorization server.
|
||||
// It is used for token validation and for STS token exchange.
|
||||
AuthorizationURL string `toml:"authorization_url,omitempty"`
|
||||
CertificateAuthority string `toml:"certificate_authority,omitempty"`
|
||||
ServerURL string `toml:"server_url,omitempty"`
|
||||
AuthorizationURL string `toml:"authorization_url,omitempty"`
|
||||
// StsClientId is the OAuth client ID used for backend token exchange
|
||||
StsClientId string `toml:"sts_client_id,omitempty"`
|
||||
// StsClientSecret is the OAuth client secret used for backend token exchange
|
||||
StsClientSecret string `toml:"sts_client_secret,omitempty"`
|
||||
// StsAudience is the audience for the STS token exchange.
|
||||
StsAudience string `toml:"sts_audience,omitempty"`
|
||||
// StsScopes is the scopes for the STS token exchange.
|
||||
StsScopes []string `toml:"sts_scopes,omitempty"`
|
||||
CertificateAuthority string `toml:"certificate_authority,omitempty"`
|
||||
ServerURL string `toml:"server_url,omitempty"`
|
||||
}
|
||||
|
||||
type GroupVersionKind struct {
|
||||
|
||||
@@ -6,15 +6,16 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
authenticationapiv1 "k8s.io/api/authentication/v1"
|
||||
"k8s.io/klog/v2"
|
||||
"k8s.io/utils/strings/slices"
|
||||
|
||||
"github.com/containers/kubernetes-mcp-server/pkg/config"
|
||||
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
|
||||
)
|
||||
|
||||
type KubernetesApiTokenVerifier interface {
|
||||
@@ -26,7 +27,7 @@ type KubernetesApiTokenVerifier interface {
|
||||
//
|
||||
// The flow is skipped for unprotected resources, such as health checks and well-known endpoints.
|
||||
//
|
||||
// There are several auth scenarios
|
||||
// There are several auth scenarios supported by this middleware:
|
||||
//
|
||||
// 1. requireOAuth is false:
|
||||
//
|
||||
@@ -42,13 +43,25 @@ type KubernetesApiTokenVerifier interface {
|
||||
// - If OAuthAudience is set, the token is validated against the audience.
|
||||
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
|
||||
//
|
||||
// see TestAuthorizationRawToken
|
||||
//
|
||||
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
|
||||
// - The token is validated offline for basic sanity checks (audience and expiration).
|
||||
// - If OAuthAudience is set, the token is validated against the audience.
|
||||
// - The token is then validated against the OIDC Provider.
|
||||
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
|
||||
//
|
||||
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
|
||||
// see TestAuthorizationOidcToken
|
||||
//
|
||||
// 2.3. OIDC Token Exchange (oidcProvider is not nil, StsClientId and StsAudience are set):
|
||||
// - The token is validated offline for basic sanity checks (audience and expiration).
|
||||
// - If OAuthAudience is set, the token is validated against the audience.
|
||||
// - The token is then validated against the OIDC Provider.
|
||||
// - If the token is valid, an external account token exchange is performed using
|
||||
// the OIDC Provider to obtain a new token with the specified audience and scopes.
|
||||
// - If ValidateToken is set, the exchanged token is then used against the Kubernetes API Server for TokenReview.
|
||||
//
|
||||
// see TestAuthorizationOidcTokenExchange
|
||||
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -96,6 +109,22 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
|
||||
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
|
||||
r = r.WithContext(context.WithValue(r.Context(), mcp.TokenScopesContextKey, scopes))
|
||||
}
|
||||
// Token exchange with OIDC provider
|
||||
sts := NewFromConfig(staticConfig, oidcProvider)
|
||||
if err == nil && sts.IsEnabled() {
|
||||
var exchangedToken *oauth2.Token
|
||||
// If the token is valid, we can exchange it for a new token with the specified audience and scopes.
|
||||
exchangedToken, err = sts.ExternalAccountTokenExchange(r.Context(), &oauth2.Token{
|
||||
AccessToken: claims.Token,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
if err == nil {
|
||||
// Replace the original token with the exchanged token
|
||||
token = exchangedToken.AccessToken
|
||||
claims, err = ParseJWTClaims(token)
|
||||
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) // TODO: Implement test to verify, THIS IS A CRITICAL PART
|
||||
}
|
||||
}
|
||||
// Kubernetes API Server TokenReview validation
|
||||
if err == nil && staticConfig.ValidateToken {
|
||||
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)
|
||||
|
||||
@@ -132,9 +132,18 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
|
||||
test(httpCtx)
|
||||
}
|
||||
|
||||
func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *oidc.Provider, httpServer *httptest.Server) {
|
||||
type OidcTestServer struct {
|
||||
*rsa.PrivateKey
|
||||
*oidc.Provider
|
||||
*httptest.Server
|
||||
TokenEndpointHandler http.HandlerFunc
|
||||
}
|
||||
|
||||
func NewOidcTestServer(t *testing.T) (oidcTestServer *OidcTestServer) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
var err error
|
||||
oidcTestServer = &OidcTestServer{}
|
||||
oidcTestServer.PrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key for oidc: %v", err)
|
||||
}
|
||||
@@ -142,15 +151,21 @@ func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *
|
||||
Algorithms: []string{oidc.RS256, oidc.ES256},
|
||||
PublicKeys: []oidctest.PublicKey{
|
||||
{
|
||||
PublicKey: privateKey.Public(),
|
||||
PublicKey: oidcTestServer.Public(),
|
||||
KeyID: "test-oidc-key-id",
|
||||
Algorithm: oidc.RS256,
|
||||
},
|
||||
},
|
||||
}
|
||||
httpServer = httptest.NewServer(oidcServer)
|
||||
oidcServer.SetIssuer(httpServer.URL)
|
||||
oidcProvider, err = oidc.NewProvider(t.Context(), httpServer.URL)
|
||||
oidcTestServer.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/token" && oidcTestServer.TokenEndpointHandler != nil {
|
||||
oidcTestServer.TokenEndpointHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
oidcServer.ServeHTTP(w, r)
|
||||
}))
|
||||
oidcServer.SetIssuer(oidcTestServer.URL)
|
||||
oidcTestServer.Provider, err = oidc.NewProvider(t.Context(), oidcTestServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create OIDC provider: %v", err)
|
||||
}
|
||||
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
|
||||
})
|
||||
})
|
||||
// Failed OIDC validation
|
||||
key, oidcProvider, httpServer := NewOidcTestServer(t)
|
||||
t.Cleanup(httpServer.Close)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
|
||||
oidcTestServer := NewOidcTestServer(t)
|
||||
t.Cleanup(oidcTestServer.Close)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
|
||||
})
|
||||
// Failed Kubernetes TokenReview
|
||||
rawClaims := `{
|
||||
"iss": "` + httpServer.URL + `",
|
||||
"iss": "` + oidcTestServer.URL + `",
|
||||
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
|
||||
"aud": "mcp-server"
|
||||
}`
|
||||
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
|
||||
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
|
||||
func TestAuthorizationRequireOAuthFalse(t *testing.T) {
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: false}}, func(ctx *httpContext) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
|
||||
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthorizationOidcToken(t *testing.T) {
|
||||
key, oidcProvider, httpServer := NewOidcTestServer(t)
|
||||
t.Cleanup(httpServer.Close)
|
||||
oidcTestServer := NewOidcTestServer(t)
|
||||
t.Cleanup(oidcTestServer.Close)
|
||||
rawClaims := `{
|
||||
"iss": "` + httpServer.URL + `",
|
||||
"iss": "` + oidcTestServer.URL + `",
|
||||
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
|
||||
"aud": "mcp-server"
|
||||
}`
|
||||
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
|
||||
validOidcToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256, rawClaims)
|
||||
cases := []bool{false, true}
|
||||
for _, validateToken := range cases {
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
|
||||
tokenReviewed := false
|
||||
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
|
||||
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationOidcTokenExchange(t *testing.T) {
|
||||
oidcTestServer := NewOidcTestServer(t)
|
||||
t.Cleanup(oidcTestServer.Close)
|
||||
rawClaims := `{
|
||||
"iss": "` + oidcTestServer.URL + `",
|
||||
"exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
|
||||
"aud": "%s"
|
||||
}`
|
||||
validOidcClientToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
|
||||
fmt.Sprintf(rawClaims, "mcp-server"))
|
||||
validOidcBackendToken := oidctest.SignIDToken(oidcTestServer.PrivateKey, "test-oidc-key-id", oidc.RS256,
|
||||
fmt.Sprintf(rawClaims, "backend-audience"))
|
||||
oidcTestServer.TokenEndpointHandler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = fmt.Fprintf(w, `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}`, validOidcBackendToken)
|
||||
}
|
||||
cases := []bool{false, true}
|
||||
for _, validateToken := range cases {
|
||||
staticConfig := &config.StaticConfig{
|
||||
RequireOAuth: true,
|
||||
OAuthAudience: "mcp-server",
|
||||
ValidateToken: validateToken,
|
||||
StsClientId: "test-sts-client-id",
|
||||
StsClientSecret: "test-sts-client-secret",
|
||||
StsAudience: "backend-audience",
|
||||
StsScopes: []string{"backend-scope"},
|
||||
}
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig, OidcProvider: oidcTestServer.Provider}, func(ctx *httpContext) {
|
||||
tokenReviewed := false
|
||||
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(tokenReviewSuccessful))
|
||||
tokenReviewed = true
|
||||
return
|
||||
}
|
||||
}))
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+validOidcClientToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get protected endpoint: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK", validateToken), func(t *testing.T) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly", validateToken), func(t *testing.T) {
|
||||
if tokenReviewed == true && !validateToken {
|
||||
t.Errorf("Expected token review to be skipped when validate-token is false, but it was performed")
|
||||
}
|
||||
if tokenReviewed == false && validateToken {
|
||||
t.Errorf("Expected token review to be performed when validate-token is true, but it was skipped")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google/externalaccount"
|
||||
|
||||
"github.com/containers/kubernetes-mcp-server/pkg/config"
|
||||
)
|
||||
|
||||
type staticSubjectTokenSupplier struct {
|
||||
@@ -26,6 +28,20 @@ type SecurityTokenService struct {
|
||||
ExternalAccountScopes []string
|
||||
}
|
||||
|
||||
func NewFromConfig(config *config.StaticConfig, provider *oidc.Provider) *SecurityTokenService {
|
||||
return &SecurityTokenService{
|
||||
Provider: provider,
|
||||
ClientId: config.StsClientId,
|
||||
ClientSecret: config.StsClientSecret,
|
||||
ExternalAccountAudience: config.StsAudience,
|
||||
ExternalAccountScopes: config.StsScopes,
|
||||
}
|
||||
}
|
||||
|
||||
func (sts *SecurityTokenService) IsEnabled() bool {
|
||||
return sts.Provider != nil && sts.ClientId != "" && sts.ExternalAccountAudience != ""
|
||||
}
|
||||
|
||||
func (sts *SecurityTokenService) ExternalAccountTokenExchange(ctx context.Context, originalToken *oauth2.Token) (*oauth2.Token, error) {
|
||||
ts, err := externalaccount.NewTokenSource(ctx, externalaccount.Config{
|
||||
TokenURL: sts.Endpoint().TokenURL,
|
||||
|
||||
@@ -12,6 +12,36 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestIsEnabled(t *testing.T) {
|
||||
disabledCases := []SecurityTokenService{
|
||||
{},
|
||||
{Provider: nil},
|
||||
{Provider: &oidc.Provider{}},
|
||||
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ClientSecret: "test-client-secret"},
|
||||
{ClientId: "test-client-id", ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
|
||||
{Provider: &oidc.Provider{}, ClientSecret: "test-client-secret", ExternalAccountAudience: "test-audience"},
|
||||
}
|
||||
for _, sts := range disabledCases {
|
||||
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = false", sts), func(t *testing.T) {
|
||||
if sts.IsEnabled() {
|
||||
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = true; want false", sts)
|
||||
}
|
||||
})
|
||||
}
|
||||
enabledCases := []SecurityTokenService{
|
||||
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience"},
|
||||
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret"},
|
||||
{Provider: &oidc.Provider{}, ClientId: "test-client-id", ExternalAccountAudience: "test-audience", ClientSecret: "test-client-secret", ExternalAccountScopes: []string{"test-scope"}},
|
||||
}
|
||||
for _, sts := range enabledCases {
|
||||
t.Run(fmt.Sprintf("SecurityTokenService{%+v}.IsEnabled() = true", sts), func(t *testing.T) {
|
||||
if !sts.IsEnabled() {
|
||||
t.Errorf("SecurityTokenService{%+v}.IsEnabled() = false; want true", sts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAccountTokenExchange(t *testing.T) {
|
||||
mockServer := test.NewMockServer()
|
||||
authServer := mockServer.Config().Host
|
||||
|
||||
Reference in New Issue
Block a user