fastapi-native routes

This commit is contained in:
Shahar Abramov
2025-04-08 19:18:51 +03:00
parent 97ffd8e7c4
commit 7f543d9423
5 changed files with 105 additions and 25 deletions

View File

@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
# Add MCP server to the FastAPI app # Add MCP server to the FastAPI app
mcp = FastApiMCP( mcp = FastApiMCP(
items.app, items.app,
mount_path="/mcp",
name="Item API MCP", name="Item API MCP",
description="MCP server for the Item API", description="MCP server for the Item API",
base_url="http://localhost:8000", base_url="http://localhost:8000",

View File

@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
# Add MCP server to the FastAPI app # Add MCP server to the FastAPI app
mcp = FastApiMCP( mcp = FastApiMCP(
items.app, items.app,
mount_path="/mcp",
name="Item API MCP", name="Item API MCP",
description="MCP server for the Item API", description="MCP server for the Item API",
base_url="http://localhost:8000", base_url="http://localhost:8000",

View File

@@ -1,21 +1,25 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional, Any, List, Union, AsyncIterator from typing import Dict, Optional, Any, List, Union, AsyncIterator
from fastapi import FastAPI, Request from fastapi import FastAPI, Request, APIRouter
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from mcp.server.lowlevel.server import Server from mcp.server.lowlevel.server import Server
from mcp.server.sse import SseServerTransport
import mcp.types as types import mcp.types as types
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
from fastapi_mcp.execute import execute_api_tool from fastapi_mcp.execute import execute_api_tool
from fastapi_mcp.transport.sse import FastApiSseTransport
from logging import getLogger
logger = getLogger(__name__)
class FastApiMCP: class FastApiMCP:
def __init__( def __init__(
self, self,
fastapi: FastAPI, fastapi: FastAPI,
mount_path: str = "/mcp",
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
@@ -29,7 +33,6 @@ class FastApiMCP:
self.name = name self.name = name
self.description = description self.description = description
self._mount_path = mount_path
self._base_url = base_url self._base_url = base_url
self._describe_all_responses = describe_all_responses self._describe_all_responses = describe_all_responses
self._describe_full_response_schema = describe_full_response_schema self._describe_full_response_schema = describe_full_response_schema
@@ -41,10 +44,12 @@ class FastApiMCP:
Create an MCP server from the FastAPI app. Create an MCP server from the FastAPI app.
Args: Args:
app: The FastAPI application fastapi: The FastAPI application
name: Name for the MCP server (defaults to app.title) name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description) description: Description for the MCP server (defaults to app.description)
base_url: Base URL for API requests (defaults to http://localhost:$PORT) base_url: Base URL for API requests. If not provided, the base URL will be determined from the
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
as the root path would be different when the app is deployed.
describe_all_responses: Whether to include all possible response schemas in tool descriptions describe_all_responses: Whether to include all possible response schemas in tool descriptions
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
@@ -126,37 +131,42 @@ class FastApiMCP:
return mcp_server return mcp_server
def mount(self) -> None: def mount(self, router: Optional[FastAPI | APIRouter] = None, mount_path: str = "/mcp") -> None:
""" """
Mount the MCP server to the FastAPI app. Mount the MCP server to the FastAPI app.
Args: Args:
app: The FastAPI application router: The FastAPI app or APIRouter to mount the MCP server to. If not provided, the MCP
mcp_server: The MCP server to mount server will be mounted to the FastAPI app.
operation_map: A mapping of operation IDs to operation details
mount_path: Path where the MCP server will be mounted mount_path: Path where the MCP server will be mounted
base_url: Base URL for API requests
""" """
# Normalize mount path # Normalize mount path
if not self._mount_path.startswith("/"): if not mount_path.startswith("/"):
self._mount_path = f"/{self._mount_path}" mount_path = f"/{mount_path}"
if self._mount_path.endswith("/"): if mount_path.endswith("/"):
self._mount_path = self._mount_path[:-1] mount_path = mount_path[:-1]
if not router:
router = self.fastapi
# Create SSE transport for MCP messages # Create SSE transport for MCP messages
sse_transport = SseServerTransport(f"{self._mount_path}/messages/") sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
# Define MCP connection handler # Route for MCP connection
@router.get(mount_path)
async def handle_mcp_connection(request: Request): async def handle_mcp_connection(request: Request):
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
await self.mcp_server.run( await self.mcp_server.run(
streams[0], reader,
streams[1], writer,
self.mcp_server.create_initialization_options( self.mcp_server.create_initialization_options(
notification_options=None, experimental_capabilities={} notification_options=None, experimental_capabilities={}
), ),
) )
# Mount the MCP connection handler # Route for MCP messages
self.fastapi.get(self._mount_path)(handle_mcp_connection) @router.post(f"{mount_path}/messages/")
self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message) async def handle_post_message(request: Request):
await sse_transport.handle_fastapi_post_message(request)
logger.info(f"MCP server listening at {mount_path}")

View File

View File

@@ -0,0 +1,72 @@
from uuid import UUID
from logging import getLogger
from fastapi import Request, Response
from pydantic import ValidationError
from mcp.server.sse import SseServerTransport
from mcp.types import JSONRPCMessage
logger = getLogger(__name__)
class FastApiSseTransport(SseServerTransport):
async def handle_fastapi_post_message(self, request: Request) -> None:
"""
A reimplementation of the handle_post_message method of SseServerTransport
that integrates better with FastAPI.
A few good reasons for doing this:
1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
approach. Mounting has some known issues and limitations.
2. Avoid re-constructing the scope, receive, and send from the request, as done
in the original implementation.
The combination of mounting a whole Starlette app and reconstructing the scope
and send from the request proved to be especially error-prone for us when using
tracing tools like Sentry, which had destructive effects on the request object
when using the original implementation.
"""
logger.debug("Handling POST message")
scope = request.scope
receive = request.receive
send = request._send
session_id_param = request.query_params.get("session_id")
if session_id_param is None:
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)
try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
writer = self._read_stream_writers.get(session_id)
if not writer:
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
body = await request.body()
logger.debug(f"Received JSON: {body.decode()}")
try:
message = JSONRPCMessage.model_validate_json(body)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
await writer.send(err)
return
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)