mirror of
https://github.com/containers/kubernetes-mcp-server.git
synced 2025-10-23 01:22:57 +03:00
feat(auth): support for VSCode auth flow (#258)
Adds DisableDynamicClientRegistration and OAuthScopes to be able to override the values proxied from the configured authorization server. DisableDynamicClientRegistration removes the registration_endpoint field from the well-known authorization resource metadata. This forces VSCode to show a for to input the Client ID and Client Secret since these can't be discovered. The OAuthScopes allows to override the scopes_supported field. VSCode automatically makes an auth request for all of the supported scopes. In many cases, this is not supported by the auth server. By providing this configuration, the user (MCP Server administrator) is able to set which scopes are effectively supported and force VSCode to only request these. Signed-off-by: Marc Nuri <marc@marcnuri.com>
This commit is contained in:
@@ -33,6 +33,11 @@ type StaticConfig struct {
|
||||
// AuthorizationURL is the URL of the OIDC authorization server.
|
||||
// It is used for token validation and for STS token exchange.
|
||||
AuthorizationURL string `toml:"authorization_url,omitempty"`
|
||||
// DisableDynamicClientRegistration indicates whether dynamic client registration is disabled.
|
||||
// If true, the .well-known endpoints will not expose the registration endpoint.
|
||||
DisableDynamicClientRegistration bool `toml:"disable_dynamic_client_registration,omitempty"`
|
||||
// OAuthScopes are the supported **client** scopes requested during the **client/frontend** OAuth flow.
|
||||
OAuthScopes []string `toml:"oauth_scopes,omitempty"`
|
||||
// StsClientId is the OAuth client ID used for backend token exchange
|
||||
StsClientId string `toml:"sts_client_id,omitempty"`
|
||||
// StsClientSecret is the OAuth client secret used for backend token exchange
|
||||
|
||||
@@ -111,6 +111,7 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi
|
||||
}
|
||||
// Token exchange with OIDC provider
|
||||
sts := NewFromConfig(staticConfig, oidcProvider)
|
||||
// TODO: Maybe the token had already been exchanged, if it has the right audience and scopes, we can skip this step.
|
||||
if err == nil && sts.IsEnabled() {
|
||||
var exchangedToken *oauth2.Token
|
||||
// If the token is valid, we can exchange it for a new token with the specified audience and scopes.
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/rsa"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -334,7 +335,28 @@ func TestWellKnownReverseProxy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
// With Authorization URL configured
|
||||
// With Authorization URL configured but invalid payload
|
||||
invalidPayloadServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`NOT A JSON PAYLOAD`))
|
||||
}))
|
||||
t.Cleanup(invalidPayloadServer.Close)
|
||||
invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true}
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) {
|
||||
for _, path := range cases {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
t.Run("Protected resource '"+path+"' with invalid Authorization URL payload returns 500 - Internal Server Error", func(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get %s endpoint: %v", path, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("Expected HTTP 500 Internal Server Error, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
// With Authorization URL configured and valid payload
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") {
|
||||
http.NotFound(w, r)
|
||||
@@ -344,7 +366,8 @@ func TestWellKnownReverseProxy(t *testing.T) {
|
||||
_, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`))
|
||||
}))
|
||||
t.Cleanup(testServer.Close)
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
|
||||
staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) {
|
||||
for _, path := range cases {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
@@ -365,6 +388,87 @@ func TestWellKnownReverseProxy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWellKnownOverrides(t *testing.T) {
|
||||
cases := []string{
|
||||
".well-known/oauth-authorization-server",
|
||||
".well-known/oauth-protected-resource",
|
||||
".well-known/openid-configuration",
|
||||
}
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`
|
||||
{
|
||||
"issuer": "https://localhost",
|
||||
"registration_endpoint": "https://localhost/clients-registrations/openid-connect",
|
||||
"require_request_uri_registration": true,
|
||||
"scopes_supported":["scope-1", "scope-2"]
|
||||
}`))
|
||||
}))
|
||||
t.Cleanup(testServer.Close)
|
||||
baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}
|
||||
// With Dynamic Client Registration disabled
|
||||
disableDynamicRegistrationConfig := baseConfig
|
||||
disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &disableDynamicRegistrationConfig}, func(ctx *httpContext) {
|
||||
for _, path := range cases {
|
||||
resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
t.Run("DisableDynamicClientRegistration removes registration_endpoint field", func(t *testing.T) {
|
||||
if strings.Contains(string(body), "registration_endpoint") {
|
||||
t.Error("Expected registration_endpoint to be removed, but it was found in the response")
|
||||
}
|
||||
})
|
||||
t.Run("DisableDynamicClientRegistration sets require_request_uri_registration = false", func(t *testing.T) {
|
||||
if !strings.Contains(string(body), `"require_request_uri_registration":false`) {
|
||||
t.Error("Expected require_request_uri_registration to be false, but it was not found in the response")
|
||||
}
|
||||
})
|
||||
t.Run("DisableDynamicClientRegistration includes/preserves scopes_supported", func(t *testing.T) {
|
||||
if !strings.Contains(string(body), `"scopes_supported":["scope-1","scope-2"]`) {
|
||||
t.Error("Expected scopes_supported to be present, but it was not found in the response")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
// With overrides for OAuth scopes (client/frontend)
|
||||
oAuthScopesConfig := baseConfig
|
||||
oAuthScopesConfig.OAuthScopes = []string{"openid", "mcp-server"}
|
||||
testCaseWithContext(t, &httpContext{StaticConfig: &oAuthScopesConfig}, func(ctx *httpContext) {
|
||||
for _, path := range cases {
|
||||
resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
t.Run("OAuthScopes overrides scopes_supported", func(t *testing.T) {
|
||||
if !strings.Contains(string(body), `"scopes_supported":["openid","mcp-server"]`) {
|
||||
t.Errorf("Expected scopes_supported to be overridden, but original was preserved, response: %s", string(body))
|
||||
}
|
||||
})
|
||||
t.Run("OAuthScopes preserves other fields", func(t *testing.T) {
|
||||
if !strings.Contains(string(body), `"issuer":"https://localhost"`) {
|
||||
t.Errorf("Expected issuer to be preserved, but got: %s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), `"registration_endpoint":"https://localhost`) {
|
||||
t.Errorf("Expected registration_endpoint to be preserved, but got: %s", string(body))
|
||||
}
|
||||
if !strings.Contains(string(body), `"require_request_uri_registration":true`) {
|
||||
t.Error("Expected require_request_uri_registration to be true, but it was not found in the response")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareLogging(t *testing.T) {
|
||||
testCase(t, func(ctx *httpContext) {
|
||||
_, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress))
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -21,7 +22,9 @@ var WellKnownEndpoints = []string{
|
||||
}
|
||||
|
||||
type WellKnown struct {
|
||||
authorizationUrl string
|
||||
authorizationUrl string
|
||||
scopesSupported []string
|
||||
disableDynamicClientRegistration bool
|
||||
}
|
||||
|
||||
var _ http.Handler = &WellKnown{}
|
||||
@@ -31,7 +34,11 @@ func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler {
|
||||
if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") {
|
||||
authorizationUrl = strings.TrimSuffix(authorizationUrl, "/")
|
||||
}
|
||||
return &WellKnown{authorizationUrl}
|
||||
return &WellKnown{
|
||||
authorizationUrl: authorizationUrl,
|
||||
disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration,
|
||||
scopesSupported: staticConfig.OAuthScopes,
|
||||
}
|
||||
}
|
||||
|
||||
func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
@@ -50,16 +57,30 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
var resourceMetadata map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&resourceMetadata)
|
||||
if err != nil {
|
||||
http.Error(writer, "Failed to read response body: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if w.disableDynamicClientRegistration {
|
||||
delete(resourceMetadata, "registration_endpoint")
|
||||
resourceMetadata["require_request_uri_registration"] = false
|
||||
}
|
||||
if len(w.scopesSupported) > 0 {
|
||||
resourceMetadata["scopes_supported"] = w.scopesSupported
|
||||
}
|
||||
body, err := json.Marshal(resourceMetadata)
|
||||
if err != nil {
|
||||
http.Error(writer, "Failed to marshal response body: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
writer.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = writer.Write(body)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user