diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 27f842e..1c256b1 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -22,7 +22,7 @@ const ( func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *oidc.Provider, mcpServer *mcp.Server) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == healthEndpoint || r.URL.Path == oauthProtectedResourceEndpoint { + if r.URL.Path == healthEndpoint || r.URL.Path == oauthProtectedResourceEndpoint || r.URL.Path == oauthAuthorizationServerEndpoint { next.ServeHTTP(w, r) return } diff --git a/pkg/http/http.go b/pkg/http/http.go index ceee09b..6ba625a 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -44,6 +44,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat mux.HandleFunc(healthEndpoint, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + mux.HandleFunc(oauthAuthorizationServerEndpoint, OAuthAuthorizationServerHandler(staticConfig)) mux.HandleFunc(oauthProtectedResourceEndpoint, OAuthProtectedResourceHandler(mcpServer, staticConfig)) ctx, cancel := context.WithCancel(ctx) diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index 8af780c..98cf127 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -286,6 +286,36 @@ func TestHealthCheck(t *testing.T) { }) } +func TestWellKnownOAuthAuthorizationServer(t *testing.T) { + // Simple http server to mock the authorization server + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/oauth-authorization-server" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"issuer": "https://example.com"}`)) + })) + t.Cleanup(testServer.Close) + testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true}}, func(ctx *httpContext) { + resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-authorization-server", ctx.HttpAddress)) + t.Cleanup(func() { _ = resp.Body.Close() }) + t.Run("Exposes .well-known/oauth-authorization-server endpoint", func(t *testing.T) { + if err != nil { + t.Fatalf("Failed to get .well-known/oauth-authorization-server endpoint: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode) + } + }) + t.Run(".well-known/oauth-authorization-server returns application/json content type", func(t *testing.T) { + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type")) + } + }) + }) +} + func TestWellKnownOAuthProtectedResource(t *testing.T) { testCase(t, func(ctx *httpContext) { resp, err := http.Get(fmt.Sprintf("http://%s/.well-known/oauth-protected-resource", ctx.HttpAddress)) diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index 371901a..e946047 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -2,6 +2,7 @@ package http import ( "encoding/json" + "io" "net/http" "github.com/containers/kubernetes-mcp-server/pkg/config" @@ -9,9 +10,42 @@ import ( ) const ( - oauthProtectedResourceEndpoint = "/.well-known/oauth-protected-resource" + oauthAuthorizationServerEndpoint = "/.well-known/oauth-authorization-server" + oauthProtectedResourceEndpoint = "/.well-known/oauth-protected-resource" ) +func OAuthAuthorizationServerHandler(staticConfig *config.StaticConfig) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if staticConfig.AuthorizationURL == "" { + http.Error(w, "Authorization URL is not configured", http.StatusNotFound) + return + } + req, err := http.NewRequest(r.Method, staticConfig.AuthorizationURL+oauthAuthorizationServerEndpoint, nil) + if err != nil { + http.Error(w, "Failed to create request: "+err.Error(), http.StatusInternalServerError) + return + } + resp, err := http.DefaultClient.Do(req.WithContext(r.Context())) + if err != nil { + http.Error(w, "Failed to perform request: "+err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Failed to read response body: "+err.Error(), http.StatusInternalServerError) + return + } + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(body) + } +} + func OAuthProtectedResourceHandler(mcpServer *mcp.Server, staticConfig *config.StaticConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json")