WIP - move tool call handler registration away from mount function

This commit is contained in:
Shahar Abramov
2025-04-08 15:45:22 +03:00
parent 45600dcbd0
commit 1bf481f8dc
2 changed files with 51 additions and 39 deletions

View File

@@ -7,7 +7,8 @@ without the intermediate step of dynamically generating Python functions.
import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Optional, Union, Tuple, AsyncIterator
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI
@@ -349,6 +350,7 @@ def create_mcp_server(
app: FastAPI,
name: Optional[str] = None,
description: Optional[str] = None,
base_url: Optional[str] = None,
describe_all_responses: bool = False,
describe_full_response_schema: bool = False,
) -> tuple[Server, Dict[str, Dict[str, Any]]]:
@@ -359,6 +361,7 @@ def create_mcp_server(
app: The FastAPI application
name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description)
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
describe_all_responses: Whether to include all possible response schemas in tool descriptions
describe_full_response_schema: Whether to include full response schema in tool descriptions
@@ -387,15 +390,57 @@ def create_mcp_server(
describe_full_response_schema=describe_full_response_schema,
)
# Determine base URL if not provided
if not base_url:
# Try to determine the base URL from FastAPI config
if hasattr(app, "root_path") and app.root_path:
base_url = app.root_path
else:
# Default to localhost with FastAPI default port
port = 8000
for route in app.routes:
if hasattr(route, "app") and hasattr(route.app, "port"):
port = route.app.port
break
base_url = f"http://localhost:{port}"
# Normalize base URL
if base_url.endswith("/"):
base_url = base_url[:-1]
# Create the MCP server
mcp_server: Server = Server(server_name, server_description)
# Create a lifespan context manager to store the base_url and operation_map
@asynccontextmanager
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
# Store context data that will be available to all server handlers
context = {"base_url": base_url, "operation_map": operation_map}
yield context
# Use our custom lifespan
mcp_server.lifespan = server_lifespan
# Register handlers for tools
@mcp_server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
"""Handler for the tools/list request"""
return tools
# Register the tool call handler
@mcp_server.call_tool()
async def handle_call_tool(
name: str, arguments: Dict[str, Any]
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
"""Handler for the tools/call request"""
# Get context from server lifespan
ctx = mcp_server.request_context
base_url = ctx.lifespan_context["base_url"]
operation_map = ctx.lifespan_context["operation_map"]
# Execute the tool
return await execute_http_tool(base_url, name, arguments, operation_map)
return mcp_server, operation_map
@@ -422,35 +467,6 @@ def mount_mcp_server(
if mount_path.endswith("/"):
mount_path = mount_path[:-1]
# Determine base URL if not provided
if not base_url:
# Try to determine the base URL from FastAPI config
if hasattr(app, "root_path") and app.root_path:
base_url = app.root_path
else:
# Default to localhost with FastAPI default port
port = 8000
for route in app.routes:
if hasattr(route, "app") and hasattr(route.app, "port"):
port = route.app.port
break
base_url = f"http://localhost:{port}"
# Normalize base URL
if base_url.endswith("/"):
base_url = base_url[:-1]
# Create final base URL to use for HTTP requests
final_base_url = base_url
# Register the tool call handler
@mcp_server.call_tool()
async def handle_call_tool(
name: str, arguments: Dict[str, Any]
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
"""Handler for the tools/call request"""
return await execute_http_tool(final_base_url, name, arguments, operation_map)
# Create SSE transport for MCP messages
from mcp.server.sse import SseServerTransport
from fastapi import Request
@@ -500,6 +516,7 @@ def add_mcp_server(
app,
name,
description,
base_url,
describe_all_responses=describe_all_responses,
describe_full_response_schema=describe_full_response_schema,
)

View File

@@ -19,6 +19,7 @@ def create_mcp_server(
name: Optional[str] = None,
description: Optional[str] = None,
capabilities: Optional[Dict[str, Any]] = None,
base_url: Optional[str] = None,
describe_all_responses: bool = False,
describe_full_response_schema: bool = False,
) -> Server:
@@ -30,6 +31,7 @@ def create_mcp_server(
name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description)
capabilities: Optional capabilities for the MCP server (ignored in direct conversion)
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error.
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens.
@@ -41,6 +43,7 @@ def create_mcp_server(
app,
name,
description,
base_url,
describe_all_responses=describe_all_responses,
describe_full_response_schema=describe_full_response_schema,
)
@@ -53,8 +56,6 @@ def mount_mcp_server(
mcp_server: Server,
mount_path: str = "/mcp",
base_url: Optional[str] = None,
describe_all_responses: bool = False,
describe_full_response_schema: bool = False,
) -> None:
"""
Mount an MCP server to a FastAPI app.
@@ -64,8 +65,6 @@ def mount_mcp_server(
mcp_server: The MCP server to mount
mount_path: Path where the MCP server will be mounted
base_url: Base URL for API requests
describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error.
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens.
"""
# Get OpenAPI schema from FastAPI app for operation mapping
from fastapi.openapi.utils import get_openapi
@@ -81,11 +80,7 @@ def mount_mcp_server(
# Extract operation map for HTTP calls
# The function returns a tuple (tools, operation_map)
result = convert_openapi_to_mcp_tools(
openapi_schema,
describe_all_responses=describe_all_responses,
describe_full_response_schema=describe_full_response_schema,
)
result = convert_openapi_to_mcp_tools(openapi_schema)
operation_map = result[1]
# Mount using the direct approach