fastapi-native routes

This commit is contained in:
Shahar Abramov
2025-04-08 19:18:51 +03:00
parent 97ffd8e7c4
commit 7f543d9423
5 changed files with 105 additions and 25 deletions

View File

@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
# Add MCP server to the FastAPI app
mcp = FastApiMCP(
items.app,
mount_path="/mcp",
name="Item API MCP",
description="MCP server for the Item API",
base_url="http://localhost:8000",

View File

@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
# Add MCP server to the FastAPI app
mcp = FastApiMCP(
items.app,
mount_path="/mcp",
name="Item API MCP",
description="MCP server for the Item API",
base_url="http://localhost:8000",

View File

@@ -1,21 +1,25 @@
from contextlib import asynccontextmanager
from typing import Dict, Optional, Any, List, Union, AsyncIterator
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, APIRouter
from fastapi.openapi.utils import get_openapi
from mcp.server.lowlevel.server import Server
from mcp.server.sse import SseServerTransport
import mcp.types as types
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
from fastapi_mcp.execute import execute_api_tool
from fastapi_mcp.transport.sse import FastApiSseTransport
from logging import getLogger
logger = getLogger(__name__)
class FastApiMCP:
def __init__(
self,
fastapi: FastAPI,
mount_path: str = "/mcp",
name: Optional[str] = None,
description: Optional[str] = None,
base_url: Optional[str] = None,
@@ -29,7 +33,6 @@ class FastApiMCP:
self.name = name
self.description = description
self._mount_path = mount_path
self._base_url = base_url
self._describe_all_responses = describe_all_responses
self._describe_full_response_schema = describe_full_response_schema
@@ -41,10 +44,12 @@ class FastApiMCP:
Create an MCP server from the FastAPI app.
Args:
app: The FastAPI application
fastapi: The FastAPI application
name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description)
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
as the root path would be different when the app is deployed.
describe_all_responses: Whether to include all possible response schemas in tool descriptions
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
@@ -126,37 +131,42 @@ class FastApiMCP:
return mcp_server
def mount(self) -> None:
def mount(self, router: Optional[FastAPI | APIRouter] = None, mount_path: str = "/mcp") -> None:
"""
Mount the MCP server to the FastAPI app.
Args:
app: The FastAPI application
mcp_server: The MCP server to mount
operation_map: A mapping of operation IDs to operation details
router: The FastAPI app or APIRouter to mount the MCP server to. If not provided, the MCP
server will be mounted to the FastAPI app.
mount_path: Path where the MCP server will be mounted
base_url: Base URL for API requests
"""
# Normalize mount path
if not self._mount_path.startswith("/"):
self._mount_path = f"/{self._mount_path}"
if self._mount_path.endswith("/"):
self._mount_path = self._mount_path[:-1]
if not mount_path.startswith("/"):
mount_path = f"/{mount_path}"
if mount_path.endswith("/"):
mount_path = mount_path[:-1]
if not router:
router = self.fastapi
# Create SSE transport for MCP messages
sse_transport = SseServerTransport(f"{self._mount_path}/messages/")
sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
# Define MCP connection handler
# Route for MCP connection
@router.get(mount_path)
async def handle_mcp_connection(request: Request):
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
await self.mcp_server.run(
streams[0],
streams[1],
reader,
writer,
self.mcp_server.create_initialization_options(
notification_options=None, experimental_capabilities={}
),
)
# Mount the MCP connection handler
self.fastapi.get(self._mount_path)(handle_mcp_connection)
self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message)
# Route for MCP messages
@router.post(f"{mount_path}/messages/")
async def handle_post_message(request: Request):
await sse_transport.handle_fastapi_post_message(request)
logger.info(f"MCP server listening at {mount_path}")

View File

View File

@@ -0,0 +1,72 @@
from uuid import UUID
from logging import getLogger
from fastapi import Request, Response
from pydantic import ValidationError
from mcp.server.sse import SseServerTransport
from mcp.types import JSONRPCMessage
logger = getLogger(__name__)
class FastApiSseTransport(SseServerTransport):
async def handle_fastapi_post_message(self, request: Request) -> None:
"""
A reimplementation of the handle_post_message method of SseServerTransport
that integrates better with FastAPI.
A few good reasons for doing this:
1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
approach. Mounting has some known issues and limitations.
2. Avoid re-constructing the scope, receive, and send from the request, as done
in the original implementation.
The combination of mounting a whole Starlette app and reconstructing the scope
and send from the request proved to be especially error-prone for us when using
tracing tools like Sentry, which had destructive effects on the request object
when using the original implementation.
"""
logger.debug("Handling POST message")
scope = request.scope
receive = request.receive
send = request._send
session_id_param = request.query_params.get("session_id")
if session_id_param is None:
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)
try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
writer = self._read_stream_writers.get(session_id)
if not writer:
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
body = await request.body()
logger.debug(f"Received JSON: {body.decode()}")
try:
message = JSONRPCMessage.model_validate_json(body)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
await writer.send(err)
return
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)