diff --git a/pkg/config/config.go b/pkg/config/config.go index 6e79797..26e007d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -33,6 +33,11 @@ type StaticConfig struct { // AuthorizationURL is the URL of the OIDC authorization server. // It is used for token validation and for STS token exchange. AuthorizationURL string `toml:"authorization_url,omitempty"` + // DisableDynamicClientRegistration indicates whether dynamic client registration is disabled. + // If true, the .well-known endpoints will not expose the registration endpoint. + DisableDynamicClientRegistration bool `toml:"disable_dynamic_client_registration,omitempty"` + // OAuthScopes are the supported **client** scopes requested during the **client/frontend** OAuth flow. + OAuthScopes []string `toml:"oauth_scopes,omitempty"` // StsClientId is the OAuth client ID used for backend token exchange StsClientId string `toml:"sts_client_id,omitempty"` // StsClientSecret is the OAuth client secret used for backend token exchange diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 6fa81f0..39259d5 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -111,6 +111,7 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi } // Token exchange with OIDC provider sts := NewFromConfig(staticConfig, oidcProvider) + // TODO: Maybe the token had already been exchanged, if it has the right audience and scopes, we can skip this step. if err == nil && sts.IsEnabled() { var exchangedToken *oauth2.Token // If the token is valid, we can exchange it for a new token with the specified audience and scopes. diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 04cccf4..0ceaf2e 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -8,6 +8,7 @@ import ( "crypto/rsa" "flag" "fmt" + "io" "net" "net/http" "net/http/httptest" @@ -334,7 +335,28 @@ func TestWellKnownReverseProxy(t *testing.T) { }) } }) - // With Authorization URL configured + // With Authorization URL configured but invalid payload + invalidPayloadServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`NOT A JSON PAYLOAD`)) + })) + t.Cleanup(invalidPayloadServer.Close) + invalidPayloadConfig := &config.StaticConfig{AuthorizationURL: invalidPayloadServer.URL, RequireOAuth: true, ValidateToken: true} + testCaseWithContext(t, &httpContext{StaticConfig: invalidPayloadConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Protected resource '"+path+"' with invalid Authorization URL payload returns 500 - Internal Server Error", func(t *testing.T) { + if err != nil { + t.Fatalf("Failed to get %s endpoint: %v", path, err) + } + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("Expected HTTP 500 Internal Server Error, got %d", resp.StatusCode) + } + }) + } + }) + // With Authorization URL configured and valid payload testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") { http.NotFound(w, r) @@ -344,7 +366,8 @@ func TestWellKnownReverseProxy(t *testing.T) { _, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`)) })) t.Cleanup(testServer.Close) - testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) { + staticConfig := &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true} + testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) { for _, path := range cases { resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) t.Cleanup(func() { _ = resp.Body.Close() }) @@ -365,6 +388,87 @@ func TestWellKnownReverseProxy(t *testing.T) { }) } +func TestWellKnownOverrides(t *testing.T) { + cases := []string{ + ".well-known/oauth-authorization-server", + ".well-known/oauth-protected-resource", + ".well-known/openid-configuration", + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(` + { + "issuer": "https://localhost", + "registration_endpoint": "https://localhost/clients-registrations/openid-connect", + "require_request_uri_registration": true, + "scopes_supported":["scope-1", "scope-2"] + }`)) + })) + t.Cleanup(testServer.Close) + baseConfig := config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true} + // With Dynamic Client Registration disabled + disableDynamicRegistrationConfig := baseConfig + disableDynamicRegistrationConfig.DisableDynamicClientRegistration = true + testCaseWithContext(t, &httpContext{StaticConfig: &disableDynamicRegistrationConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + t.Run("DisableDynamicClientRegistration removes registration_endpoint field", func(t *testing.T) { + if strings.Contains(string(body), "registration_endpoint") { + t.Error("Expected registration_endpoint to be removed, but it was found in the response") + } + }) + t.Run("DisableDynamicClientRegistration sets require_request_uri_registration = false", func(t *testing.T) { + if !strings.Contains(string(body), `"require_request_uri_registration":false`) { + t.Error("Expected require_request_uri_registration to be false, but it was not found in the response") + } + }) + t.Run("DisableDynamicClientRegistration includes/preserves scopes_supported", func(t *testing.T) { + if !strings.Contains(string(body), `"scopes_supported":["scope-1","scope-2"]`) { + t.Error("Expected scopes_supported to be present, but it was not found in the response") + } + }) + } + }) + // With overrides for OAuth scopes (client/frontend) + oAuthScopesConfig := baseConfig + oAuthScopesConfig.OAuthScopes = []string{"openid", "mcp-server"} + testCaseWithContext(t, &httpContext{StaticConfig: &oAuthScopesConfig}, func(ctx *httpContext) { + for _, path := range cases { + resp, _ := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path)) + t.Cleanup(func() { _ = resp.Body.Close() }) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + t.Run("OAuthScopes overrides scopes_supported", func(t *testing.T) { + if !strings.Contains(string(body), `"scopes_supported":["openid","mcp-server"]`) { + t.Errorf("Expected scopes_supported to be overridden, but original was preserved, response: %s", string(body)) + } + }) + t.Run("OAuthScopes preserves other fields", func(t *testing.T) { + if !strings.Contains(string(body), `"issuer":"https://localhost"`) { + t.Errorf("Expected issuer to be preserved, but got: %s", string(body)) + } + if !strings.Contains(string(body), `"registration_endpoint":"https://localhost`) { + t.Errorf("Expected registration_endpoint to be preserved, but got: %s", string(body)) + } + if !strings.Contains(string(body), `"require_request_uri_registration":true`) { + t.Error("Expected require_request_uri_registration to be true, but it was not found in the response") + } + }) + } + }) +} + func TestMiddlewareLogging(t *testing.T) { testCase(t, func(ctx *httpContext) { _, _ = http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress)) diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index c1e375e..0d80221 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -1,7 +1,8 @@ package http import ( - "io" + "encoding/json" + "fmt" "net/http" "strings" @@ -21,7 +22,9 @@ var WellKnownEndpoints = []string{ } type WellKnown struct { - authorizationUrl string + authorizationUrl string + scopesSupported []string + disableDynamicClientRegistration bool } var _ http.Handler = &WellKnown{} @@ -31,7 +34,11 @@ func WellKnownHandler(staticConfig *config.StaticConfig) http.Handler { if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") { authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") } - return &WellKnown{authorizationUrl} + return &WellKnown{ + authorizationUrl: authorizationUrl, + disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration, + scopesSupported: staticConfig.OAuthScopes, + } } func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -50,16 +57,30 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) return } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) + var resourceMetadata map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&resourceMetadata) if err != nil { http.Error(writer, "Failed to read response body: "+err.Error(), http.StatusInternalServerError) return } + if w.disableDynamicClientRegistration { + delete(resourceMetadata, "registration_endpoint") + resourceMetadata["require_request_uri_registration"] = false + } + if len(w.scopesSupported) > 0 { + resourceMetadata["scopes_supported"] = w.scopesSupported + } + body, err := json.Marshal(resourceMetadata) + if err != nil { + http.Error(writer, "Failed to marshal response body: "+err.Error(), http.StatusInternalServerError) + return + } for key, values := range resp.Header { for _, value := range values { writer.Header().Add(key, value) } } + writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) writer.WriteHeader(resp.StatusCode) _, _ = writer.Write(body) }