mirror of
https://github.com/openshift/openshift-mcp-server.git
synced 2025-10-17 14:27:48 +03:00
feat(auth): introduce OIDC token verification if authorization-url is specified (176)
Pass correct audience --- Validate server and authorization url via url.Parse --- Import go-oidc/v3 --- Wire initialized oidc provider if authorization url is set --- Wire oidc issuer validation
This commit is contained in:
2
go.mod
2
go.mod
@@ -4,6 +4,7 @@ go 1.24.1
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/mark3labs/mcp-go v0.33.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
@@ -51,6 +52,7 @@ require (
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||
github.com/go-errors/errors v1.4.2 // indirect
|
||||
github.com/go-gorp/gorp/v3 v3.1.0 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/jsonreference v0.20.2 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -44,6 +44,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
|
||||
github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
@@ -92,6 +94,8 @@ github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxI
|
||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs=
|
||||
github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
|
||||
@@ -31,37 +32,83 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp
|
||||
return
|
||||
}
|
||||
|
||||
audience := Audience
|
||||
if serverURL != "" {
|
||||
audience = serverURL
|
||||
}
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
audience := Audience
|
||||
if serverURL != "" {
|
||||
audience = serverURL
|
||||
}
|
||||
|
||||
err := validateJWTToken(token, audience)
|
||||
// Validate the token offline for simple sanity check
|
||||
// Because missing expected audience and expired tokens must be
|
||||
// rejected already.
|
||||
claims, err := validateJWTToken(token, audience)
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token using Kubernetes TokenReview API
|
||||
_, _, err = mcpServer.VerifyToken(r.Context(), token, Audience)
|
||||
oidcProvider := mcpServer.GetOIDCProvider()
|
||||
if oidcProvider != nil {
|
||||
// If OIDC Provider is configured, this token must be validated against it.
|
||||
if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil {
|
||||
klog.V(1).Infof("Authentication failed - OIDC token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes are likely to be used for authorization.
|
||||
scopes := claims.GetScopes()
|
||||
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
|
||||
|
||||
// Now, there are a couple of options:
|
||||
// 1. If there is no authorization url configured for this MCP Server,
|
||||
// that means this token will be used against the Kubernetes API Server.
|
||||
// So that we need to validate the token using Kubernetes TokenReview API beforehand.
|
||||
// 2. If there is an authorization url configured for this MCP Server,
|
||||
// that means up to this point, the token is validated against the OIDC Provider already.
|
||||
// 2. a. If this is the only token in the headers, this validated token
|
||||
// is supposed to be used against the Kubernetes API Server as well. Therefore,
|
||||
// TokenReview request must succeed.
|
||||
// 2. b. If this is not the only token in the headers, the token in here is used
|
||||
// only for authentication and authorization. Therefore, we need to send TokenReview request
|
||||
// with the other token in the headers (TODO: still need to validate aud and exp of this token separately).
|
||||
_, _, err = mcpServer.VerifyTokenAPIServer(r.Context(), token, audience)
|
||||
if err != nil {
|
||||
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
|
||||
if serverURL == "" {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
|
||||
} else {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
|
||||
}
|
||||
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
@@ -72,32 +119,60 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp
|
||||
}
|
||||
|
||||
type JWTClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Audience []string `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
Issuer string `json:"iss"`
|
||||
Audience any `json:"aud"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// validateJWTToken validates basic JWT claims without signature verification
|
||||
func validateJWTToken(token, audience string) error {
|
||||
func (c *JWTClaims) GetScopes() []string {
|
||||
if c.Scope == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Fields(c.Scope)
|
||||
}
|
||||
|
||||
func (c *JWTClaims) ContainsAudience(audience string) bool {
|
||||
switch aud := c.Audience.(type) {
|
||||
case string:
|
||||
return aud == audience
|
||||
case []interface{}:
|
||||
for _, a := range aud {
|
||||
if str, ok := a.(string); ok && str == audience {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, a := range aud {
|
||||
if a == audience {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validateJWTToken validates basic JWT claims without signature verification and returns the claims
|
||||
func validateJWTToken(token, audience string) (*JWTClaims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid JWT token format")
|
||||
return nil, fmt.Errorf("invalid JWT token format")
|
||||
}
|
||||
|
||||
claims, err := parseJWTClaims(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
|
||||
return fmt.Errorf("token expired")
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, audience) {
|
||||
return fmt.Errorf("token audience mismatch: %v", claims.Audience)
|
||||
if !claims.ContainsAudience(audience) {
|
||||
return nil, fmt.Errorf("token audience mismatch: %v", claims.Audience)
|
||||
}
|
||||
|
||||
return nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func parseJWTClaims(payload string) (*JWTClaims, error) {
|
||||
@@ -118,3 +193,16 @@ func parseJWTClaims(payload string) (*JWTClaims, error) {
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {
|
||||
verifier := provider.Verifier(&oidc.Config{
|
||||
ClientID: audience,
|
||||
})
|
||||
|
||||
_, err := verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JWT token verification failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,13 +29,9 @@ func TestParseJWTClaims(t *testing.T) {
|
||||
}
|
||||
|
||||
expectedAudiences := []string{"https://kubernetes.default.svc.cluster.local", "kubernetes-mcp-server"}
|
||||
if len(claims.Audience) != 2 {
|
||||
t.Errorf("expected 2 audiences, got %d", len(claims.Audience))
|
||||
}
|
||||
|
||||
for i, expected := range expectedAudiences {
|
||||
if i >= len(claims.Audience) || claims.Audience[i] != expected {
|
||||
t.Errorf("expected audience[%d] to be %s, got %s", i, expected, claims.Audience[i])
|
||||
for _, expected := range expectedAudiences {
|
||||
if !claims.ContainsAudience(expected) {
|
||||
t.Errorf("expected audience to contain %s", expected)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +44,7 @@ func TestParseJWTClaims(t *testing.T) {
|
||||
// Create a payload that needs padding
|
||||
testClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"test-audience"},
|
||||
Audience: "test-audience",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
@@ -98,7 +94,7 @@ func TestValidateJWTToken(t *testing.T) {
|
||||
t.Run("invalid token format - not enough parts", func(t *testing.T) {
|
||||
invalidToken := "header.payload"
|
||||
|
||||
err := validateJWTToken(invalidToken, "test")
|
||||
_, err := validateJWTToken(invalidToken, "test")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid token format, got nil")
|
||||
}
|
||||
@@ -112,15 +108,15 @@ func TestValidateJWTToken(t *testing.T) {
|
||||
// Create an expired token
|
||||
expiredClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"kubernetes-mcp-server"},
|
||||
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // 1 hour ago
|
||||
Audience: "kubernetes-mcp-server",
|
||||
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(expiredClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
expiredToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
|
||||
_, err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
|
||||
if err == nil {
|
||||
t.Error("expected error for expired token, got nil")
|
||||
}
|
||||
@@ -136,23 +132,42 @@ func TestValidateJWTToken(t *testing.T) {
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"other-audience", "kubernetes-mcp-server", "another-audience"},
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
Scope: "read write admin",
|
||||
}
|
||||
|
||||
jsonBytes, _ := json.Marshal(multiAudClaims)
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
multiAudToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
|
||||
claims, err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
|
||||
if err != nil {
|
||||
t.Errorf("expected no error for token with multiple audiences, got %v", err)
|
||||
}
|
||||
if claims == nil {
|
||||
t.Error("expected claims to be returned, got nil")
|
||||
}
|
||||
if claims.Issuer != "test-issuer" {
|
||||
t.Errorf("expected issuer 'test-issuer', got %s", claims.Issuer)
|
||||
}
|
||||
|
||||
// Test scope parsing
|
||||
scopes := claims.GetScopes()
|
||||
expectedScopes := []string{"read", "write", "admin"}
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
for i, expectedScope := range expectedScopes {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("audience mismatch", func(t *testing.T) {
|
||||
// Create a token with wrong audience
|
||||
wrongAudClaims := JWTClaims{
|
||||
Issuer: "test-issuer",
|
||||
Audience: []string{"wrong-audience"},
|
||||
Audience: "wrong-audience",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
@@ -160,7 +175,7 @@ func TestValidateJWTToken(t *testing.T) {
|
||||
payload := base64.URLEncoding.EncodeToString(jsonBytes)
|
||||
wrongAudToken := "header." + payload + ".signature"
|
||||
|
||||
err := validateJWTToken(wrongAudToken, "audience")
|
||||
_, err := validateJWTToken(wrongAudToken, "audience")
|
||||
if err == nil {
|
||||
t.Error("expected error for token with wrong audience, got nil")
|
||||
}
|
||||
@@ -270,3 +285,85 @@ func TestAuthorizationMiddleware(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTClaimsGetScopes(t *testing.T) {
|
||||
t.Run("single scope", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: "read"}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read"}
|
||||
|
||||
if len(scopes) != 1 {
|
||||
t.Errorf("expected 1 scope, got %d", len(scopes))
|
||||
}
|
||||
if scopes[0] != expected[0] {
|
||||
t.Errorf("expected scope 'read', got '%s'", scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple scopes", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: "read write admin"}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("scopes with extra whitespace", func(t *testing.T) {
|
||||
claims := &JWTClaims{Scope: " read write admin "}
|
||||
scopes := claims.GetScopes()
|
||||
expected := []string{"read", "write", "admin"}
|
||||
|
||||
if len(scopes) != 3 {
|
||||
t.Errorf("expected 3 scopes, got %d", len(scopes))
|
||||
}
|
||||
|
||||
for i, expectedScope := range expected {
|
||||
if i >= len(scopes) || scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope[%d] to be '%s', got '%s'", i, expectedScope, scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTClaimsContainsAudience(t *testing.T) {
|
||||
t.Run("single string audience", func(t *testing.T) {
|
||||
claims := &JWTClaims{Audience: "test-audience"}
|
||||
|
||||
if !claims.ContainsAudience("test-audience") {
|
||||
t.Error("expected ContainsAudience to return true for matching audience")
|
||||
}
|
||||
|
||||
if claims.ContainsAudience("other-audience") {
|
||||
t.Error("expected ContainsAudience to return false for non-matching audience")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("array audience", func(t *testing.T) {
|
||||
claims := &JWTClaims{Audience: []string{"aud1", "aud2", "aud3"}}
|
||||
|
||||
testCases := []struct {
|
||||
audience string
|
||||
expected bool
|
||||
}{
|
||||
{"aud1", true},
|
||||
{"aud2", true},
|
||||
{"aud3", true},
|
||||
{"aud4", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if claims.ContainsAudience(tc.audience) != tc.expected {
|
||||
t.Errorf("expected ContainsAudience(%s) to return %v", tc.audience, tc.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
|
||||
)
|
||||
|
||||
const oauthProtectedResourceEndpoint = "/.well-known/oauth-protected-resource"
|
||||
|
||||
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig) error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
@@ -36,7 +38,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc(oauthProtectedResourceEndpoint, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var authServers []string
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"k8s.io/cli-runtime/pkg/genericiooptions"
|
||||
@@ -196,20 +198,28 @@ func (m *MCPServerOptions) Validate() error {
|
||||
if !m.StaticConfig.RequireOAuth && (m.StaticConfig.AuthorizationURL != "" || m.StaticConfig.ServerURL != "") {
|
||||
return fmt.Errorf("authorization-url and server-url are only valid if require-oauth is enabled")
|
||||
}
|
||||
if m.StaticConfig.AuthorizationURL != "" &&
|
||||
!strings.HasPrefix(m.StaticConfig.AuthorizationURL, "https://") {
|
||||
if strings.HasPrefix(m.StaticConfig.AuthorizationURL, "http://") {
|
||||
if m.StaticConfig.AuthorizationURL != "" {
|
||||
u, err := url.Parse(m.StaticConfig.AuthorizationURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fmt.Errorf("--authorization-url must be a valid URL")
|
||||
}
|
||||
if u.Scheme == "http" {
|
||||
klog.Warningf("authorization-url is using http://, this is not recommended production use")
|
||||
} else {
|
||||
return fmt.Errorf("authorization-url must start with https://")
|
||||
}
|
||||
}
|
||||
if m.StaticConfig.ServerURL != "" &&
|
||||
!strings.HasPrefix(m.StaticConfig.ServerURL, "https://") {
|
||||
if strings.HasPrefix(m.StaticConfig.ServerURL, "http://") {
|
||||
if m.StaticConfig.ServerURL != "" {
|
||||
u, err := url.Parse(m.StaticConfig.ServerURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fmt.Errorf("--server-url must be a valid URL")
|
||||
}
|
||||
if u.Scheme == "http" {
|
||||
klog.Warningf("server-url is using http://, this is not recommended production use")
|
||||
} else {
|
||||
return fmt.Errorf("server-url must start with https://")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -235,10 +245,21 @@ func (m *MCPServerOptions) Run() error {
|
||||
_, _ = fmt.Fprintf(m.Out, "%s\n", version.Version)
|
||||
return nil
|
||||
}
|
||||
|
||||
var oidcProvider *oidc.Provider
|
||||
if m.StaticConfig.AuthorizationURL != "" {
|
||||
provider, err := oidc.NewProvider(context.TODO(), m.StaticConfig.AuthorizationURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to setup OIDC provider: %w", err)
|
||||
}
|
||||
oidcProvider = provider
|
||||
}
|
||||
|
||||
mcpServer, err := mcp.NewServer(mcp.Configuration{
|
||||
Profile: profile,
|
||||
ListOutput: listOutput,
|
||||
StaticConfig: m.StaticConfig,
|
||||
OIDCProvider: oidcProvider,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to initialize MCP server: %w\n", err)
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestAuthorizationURL(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid authorization-url without protocol, got nil")
|
||||
}
|
||||
expected := "authorization-url must start with https://"
|
||||
expected := "--authorization-url must be a valid URL"
|
||||
if !strings.Contains(err.Error(), expected) {
|
||||
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func TestServerURL(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid server-url without protocol, got nil")
|
||||
}
|
||||
expected := "server-url must start with https://"
|
||||
expected := "--server-url must be a valid URL"
|
||||
if !strings.Contains(err.Error(), expected) {
|
||||
t.Fatalf("Expected error to contain %s, got %s", expected, err.Error())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
authenticationapiv1 "k8s.io/api/authentication/v1"
|
||||
@@ -18,8 +19,9 @@ import (
|
||||
)
|
||||
|
||||
type Configuration struct {
|
||||
Profile Profile
|
||||
ListOutput output.Output
|
||||
Profile Profile
|
||||
ListOutput output.Output
|
||||
OIDCProvider *oidc.Provider
|
||||
|
||||
StaticConfig *config.StaticConfig
|
||||
}
|
||||
@@ -105,9 +107,9 @@ func (s *Server) ServeHTTP(httpServer *http.Server) *server.StreamableHTTPServer
|
||||
return server.NewStreamableHTTPServer(s.server, options...)
|
||||
}
|
||||
|
||||
// VerifyToken verifies the given token with the audience by
|
||||
// VerifyTokenAPIServer verifies the given token with the audience by
|
||||
// sending an TokenReview request to API Server.
|
||||
func (s *Server) VerifyToken(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
|
||||
func (s *Server) VerifyTokenAPIServer(ctx context.Context, token string, audience string) (*authenticationapiv1.UserInfo, []string, error) {
|
||||
if s.k == nil {
|
||||
return nil, nil, fmt.Errorf("kubernetes manager is not initialized")
|
||||
}
|
||||
@@ -122,6 +124,13 @@ func (s *Server) GetKubernetesAPIServerHost() string {
|
||||
return s.k.GetAPIServerHost()
|
||||
}
|
||||
|
||||
func (s *Server) GetOIDCProvider() *oidc.Provider {
|
||||
if s.configuration.OIDCProvider == nil {
|
||||
return nil
|
||||
}
|
||||
return s.configuration.OIDCProvider
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
if s.k != nil {
|
||||
s.k.Close()
|
||||
|
||||
Reference in New Issue
Block a user