mirror of
https://github.com/tadata-org/fastapi_mcp.git
synced 2025-04-13 23:32:11 +03:00
fastapi-native routes
This commit is contained in:
@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
|
||||
# Add MCP server to the FastAPI app
|
||||
mcp = FastApiMCP(
|
||||
items.app,
|
||||
mount_path="/mcp",
|
||||
name="Item API MCP",
|
||||
description="MCP server for the Item API",
|
||||
base_url="http://localhost:8000",
|
||||
|
||||
@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
|
||||
# Add MCP server to the FastAPI app
|
||||
mcp = FastApiMCP(
|
||||
items.app,
|
||||
mount_path="/mcp",
|
||||
name="Item API MCP",
|
||||
description="MCP server for the Item API",
|
||||
base_url="http://localhost:8000",
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, Any, List, Union, AsyncIterator
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, APIRouter
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from mcp.server.lowlevel.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
import mcp.types as types
|
||||
|
||||
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
|
||||
from fastapi_mcp.execute import execute_api_tool
|
||||
from fastapi_mcp.transport.sse import FastApiSseTransport
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FastApiMCP:
|
||||
def __init__(
|
||||
self,
|
||||
fastapi: FastAPI,
|
||||
mount_path: str = "/mcp",
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
@@ -29,7 +33,6 @@ class FastApiMCP:
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
self._mount_path = mount_path
|
||||
self._base_url = base_url
|
||||
self._describe_all_responses = describe_all_responses
|
||||
self._describe_full_response_schema = describe_full_response_schema
|
||||
@@ -41,10 +44,12 @@ class FastApiMCP:
|
||||
Create an MCP server from the FastAPI app.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
fastapi: The FastAPI application
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
|
||||
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
|
||||
as the root path would be different when the app is deployed.
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||
|
||||
@@ -126,37 +131,42 @@ class FastApiMCP:
|
||||
|
||||
return mcp_server
|
||||
|
||||
def mount(self) -> None:
|
||||
def mount(self, router: Optional[FastAPI | APIRouter] = None, mount_path: str = "/mcp") -> None:
|
||||
"""
|
||||
Mount the MCP server to the FastAPI app.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
mcp_server: The MCP server to mount
|
||||
operation_map: A mapping of operation IDs to operation details
|
||||
router: The FastAPI app or APIRouter to mount the MCP server to. If not provided, the MCP
|
||||
server will be mounted to the FastAPI app.
|
||||
mount_path: Path where the MCP server will be mounted
|
||||
base_url: Base URL for API requests
|
||||
"""
|
||||
# Normalize mount path
|
||||
if not self._mount_path.startswith("/"):
|
||||
self._mount_path = f"/{self._mount_path}"
|
||||
if self._mount_path.endswith("/"):
|
||||
self._mount_path = self._mount_path[:-1]
|
||||
if not mount_path.startswith("/"):
|
||||
mount_path = f"/{mount_path}"
|
||||
if mount_path.endswith("/"):
|
||||
mount_path = mount_path[:-1]
|
||||
|
||||
if not router:
|
||||
router = self.fastapi
|
||||
|
||||
# Create SSE transport for MCP messages
|
||||
sse_transport = SseServerTransport(f"{self._mount_path}/messages/")
|
||||
sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
|
||||
|
||||
# Define MCP connection handler
|
||||
# Route for MCP connection
|
||||
@router.get(mount_path)
|
||||
async def handle_mcp_connection(request: Request):
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
|
||||
await self.mcp_server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
reader,
|
||||
writer,
|
||||
self.mcp_server.create_initialization_options(
|
||||
notification_options=None, experimental_capabilities={}
|
||||
),
|
||||
)
|
||||
|
||||
# Mount the MCP connection handler
|
||||
self.fastapi.get(self._mount_path)(handle_mcp_connection)
|
||||
self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message)
|
||||
# Route for MCP messages
|
||||
@router.post(f"{mount_path}/messages/")
|
||||
async def handle_post_message(request: Request):
|
||||
await sse_transport.handle_fastapi_post_message(request)
|
||||
|
||||
logger.info(f"MCP server listening at {mount_path}")
|
||||
|
||||
0
fastapi_mcp/transport/__init__.py
Normal file
0
fastapi_mcp/transport/__init__.py
Normal file
72
fastapi_mcp/transport/sse.py
Normal file
72
fastapi_mcp/transport/sse.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from uuid import UUID
|
||||
from logging import getLogger
|
||||
|
||||
from fastapi import Request, Response
|
||||
from pydantic import ValidationError
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FastApiSseTransport(SseServerTransport):
|
||||
async def handle_fastapi_post_message(self, request: Request) -> None:
|
||||
"""
|
||||
A reimplementation of the handle_post_message method of SseServerTransport
|
||||
that integrates better with FastAPI.
|
||||
|
||||
A few good reasons for doing this:
|
||||
1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
|
||||
approach. Mounting has some known issues and limitations.
|
||||
2. Avoid re-constructing the scope, receive, and send from the request, as done
|
||||
in the original implementation.
|
||||
|
||||
The combination of mounting a whole Starlette app and reconstructing the scope
|
||||
and send from the request proved to be especially error-prone for us when using
|
||||
tracing tools like Sentry, which had destructive effects on the request object
|
||||
when using the original implementation.
|
||||
"""
|
||||
|
||||
logger.debug("Handling POST message")
|
||||
scope = request.scope
|
||||
receive = request.receive
|
||||
send = request._send
|
||||
|
||||
session_id_param = request.query_params.get("session_id")
|
||||
if session_id_param is None:
|
||||
logger.warning("Received request without session_id")
|
||||
response = Response("session_id is required", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
try:
|
||||
session_id = UUID(hex=session_id_param)
|
||||
logger.debug(f"Parsed session ID: {session_id}")
|
||||
except ValueError:
|
||||
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
body = await request.body()
|
||||
logger.debug(f"Received JSON: {body.decode()}")
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(body)
|
||||
logger.debug(f"Validated client message: {message}")
|
||||
except ValidationError as err:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
logger.debug(f"Sending message to writer: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(message)
|
||||
Reference in New Issue
Block a user