diff --git a/.gitignore b/.gitignore index 8b36a25..12b624e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +_output/ .idea/ .vscode/ .docusaurus/ diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 5b84097..938a8c7 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -69,7 +69,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider * http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) return } - + if oidcProvider != nil { // If OIDC Provider is configured, this token must be validated against it. if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil { diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index d1b0104..f0a6f84 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -1,13 +1,18 @@ package mcp import ( + "bytes" "context" "encoding/json" + "flag" "fmt" + "k8s.io/klog/v2" + "k8s.io/klog/v2/textlogger" "net/http/httptest" "os" "path/filepath" "runtime" + "strconv" "testing" "time" @@ -97,6 +102,7 @@ func TestMain(m *testing.M) { type mcpContext struct { profile Profile listOutput output.Output + logLevel int staticConfig *config.StaticConfig clientOptions []transport.ClientOption @@ -108,6 +114,8 @@ type mcpContext struct { mcpServer *Server mcpHttpServer *httptest.Server mcpClient *client.Client + klogState klog.State + logBuffer bytes.Buffer } func (c *mcpContext) beforeEach(t *testing.T) { @@ -130,6 +138,13 @@ func (c *mcpContext) beforeEach(t *testing.T) { if c.before != nil { c.before(c) } + // Set up logging + c.klogState = klog.CaptureState() + flags := flag.NewFlagSet("test", flag.ContinueOnError) + klog.InitFlags(flags) + _ = flags.Set("v", strconv.Itoa(c.logLevel)) + klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(c.logLevel), textlogger.Output(&c.logBuffer)))) + // MCP Server if c.mcpServer, err = NewServer(Configuration{ Profile: c.profile, ListOutput: c.listOutput, @@ -143,6 +158,7 @@ func (c *mcpContext) beforeEach(t *testing.T) { t.Fatal(err) return } + // MCP Client if err = c.mcpClient.Start(c.ctx); err != nil { t.Fatal(err) return @@ -165,6 +181,7 @@ func (c *mcpContext) afterEach() { c.mcpServer.Close() _ = c.mcpClient.Close() c.mcpHttpServer.Close() + c.klogState.Restore() } func testCase(t *testing.T, test func(c *mcpContext)) { diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 2c403c8..83bca6f 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -3,6 +3,7 @@ package mcp import ( "context" "fmt" + "k8s.io/klog/v2" "net/http" "slices" @@ -56,6 +57,7 @@ func NewServer(configuration Configuration) (*Server, error) { server.WithPromptCapabilities(true), server.WithToolCapabilities(true), server.WithLogging(), + server.WithToolHandlerMiddleware(toolCallLoggingMiddleware), ), } if err := s.reloadKubernetesClient(); err != nil { @@ -165,3 +167,10 @@ func contextFunc(ctx context.Context, r *http.Request) context.Context { return ctx } + +func toolCallLoggingMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func(ctx context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) { + klog.V(5).Infof("mcp tool call: %s(%v)", ctr.Params.Name, ctr.Params.Arguments) + return next(ctx, ctr) + } +} diff --git a/pkg/mcp/mcp_tools_test.go b/pkg/mcp/mcp_tools_test.go index bd7d36b..03dc2a2 100644 --- a/pkg/mcp/mcp_tools_test.go +++ b/pkg/mcp/mcp_tools_test.go @@ -1,10 +1,11 @@ package mcp import ( - "k8s.io/utils/ptr" - "testing" - "github.com/mark3labs/mcp-go/mcp" + "k8s.io/utils/ptr" + "regexp" + "strings" + "testing" "github.com/manusa/kubernetes-mcp-server/pkg/config" ) @@ -116,3 +117,27 @@ func TestDisabledTools(t *testing.T) { }) }) } + +func TestToolCallLogging(t *testing.T) { + testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) { + _, _ = c.callTool("configuration_view", map[string]interface{}{ + "minified": false, + }) + t.Run("Logs tool name", func(t *testing.T) { + expectedLog := "mcp tool call: configuration_view(" + if !strings.Contains(c.logBuffer.String(), expectedLog) { + t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String()) + } + }) + t.Run("Logs tool call arguments", func(t *testing.T) { + expected := `"mcp tool call: configuration_view\((.+)\)"` + m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String()) + if len(m) != 2 { + t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String()) + } + if m[1] != "map[minified:false]" { + t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1]) + } + }) + }) +}