mirror of
https://github.com/tadata-org/fastapi_mcp.git
synced 2025-04-13 23:32:11 +03:00
fastapi-native routes
This commit is contained in:
@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
|
|||||||
# Add MCP server to the FastAPI app
|
# Add MCP server to the FastAPI app
|
||||||
mcp = FastApiMCP(
|
mcp = FastApiMCP(
|
||||||
items.app,
|
items.app,
|
||||||
mount_path="/mcp",
|
|
||||||
name="Item API MCP",
|
name="Item API MCP",
|
||||||
description="MCP server for the Item API",
|
description="MCP server for the Item API",
|
||||||
base_url="http://localhost:8000",
|
base_url="http://localhost:8000",
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from fastapi_mcp import FastApiMCP
|
|||||||
# Add MCP server to the FastAPI app
|
# Add MCP server to the FastAPI app
|
||||||
mcp = FastApiMCP(
|
mcp = FastApiMCP(
|
||||||
items.app,
|
items.app,
|
||||||
mount_path="/mcp",
|
|
||||||
name="Item API MCP",
|
name="Item API MCP",
|
||||||
description="MCP server for the Item API",
|
description="MCP server for the Item API",
|
||||||
base_url="http://localhost:8000",
|
base_url="http://localhost:8000",
|
||||||
|
|||||||
@@ -1,21 +1,25 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, Optional, Any, List, Union, AsyncIterator
|
from typing import Dict, Optional, Any, List, Union, AsyncIterator
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request, APIRouter
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
from mcp.server.lowlevel.server import Server
|
from mcp.server.lowlevel.server import Server
|
||||||
from mcp.server.sse import SseServerTransport
|
|
||||||
import mcp.types as types
|
import mcp.types as types
|
||||||
|
|
||||||
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
|
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
|
||||||
from fastapi_mcp.execute import execute_api_tool
|
from fastapi_mcp.execute import execute_api_tool
|
||||||
|
from fastapi_mcp.transport.sse import FastApiSseTransport
|
||||||
|
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FastApiMCP:
|
class FastApiMCP:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fastapi: FastAPI,
|
fastapi: FastAPI,
|
||||||
mount_path: str = "/mcp",
|
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
@@ -29,7 +33,6 @@ class FastApiMCP:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
|
|
||||||
self._mount_path = mount_path
|
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._describe_all_responses = describe_all_responses
|
self._describe_all_responses = describe_all_responses
|
||||||
self._describe_full_response_schema = describe_full_response_schema
|
self._describe_full_response_schema = describe_full_response_schema
|
||||||
@@ -41,10 +44,12 @@ class FastApiMCP:
|
|||||||
Create an MCP server from the FastAPI app.
|
Create an MCP server from the FastAPI app.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: The FastAPI application
|
fastapi: The FastAPI application
|
||||||
name: Name for the MCP server (defaults to app.title)
|
name: Name for the MCP server (defaults to app.title)
|
||||||
description: Description for the MCP server (defaults to app.description)
|
description: Description for the MCP server (defaults to app.description)
|
||||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
|
||||||
|
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
|
||||||
|
as the root path would be different when the app is deployed.
|
||||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||||
|
|
||||||
@@ -126,37 +131,42 @@ class FastApiMCP:
|
|||||||
|
|
||||||
return mcp_server
|
return mcp_server
|
||||||
|
|
||||||
def mount(self) -> None:
|
def mount(self, router: Optional[FastAPI | APIRouter] = None, mount_path: str = "/mcp") -> None:
|
||||||
"""
|
"""
|
||||||
Mount the MCP server to the FastAPI app.
|
Mount the MCP server to the FastAPI app.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: The FastAPI application
|
router: The FastAPI app or APIRouter to mount the MCP server to. If not provided, the MCP
|
||||||
mcp_server: The MCP server to mount
|
server will be mounted to the FastAPI app.
|
||||||
operation_map: A mapping of operation IDs to operation details
|
|
||||||
mount_path: Path where the MCP server will be mounted
|
mount_path: Path where the MCP server will be mounted
|
||||||
base_url: Base URL for API requests
|
|
||||||
"""
|
"""
|
||||||
# Normalize mount path
|
# Normalize mount path
|
||||||
if not self._mount_path.startswith("/"):
|
if not mount_path.startswith("/"):
|
||||||
self._mount_path = f"/{self._mount_path}"
|
mount_path = f"/{mount_path}"
|
||||||
if self._mount_path.endswith("/"):
|
if mount_path.endswith("/"):
|
||||||
self._mount_path = self._mount_path[:-1]
|
mount_path = mount_path[:-1]
|
||||||
|
|
||||||
|
if not router:
|
||||||
|
router = self.fastapi
|
||||||
|
|
||||||
# Create SSE transport for MCP messages
|
# Create SSE transport for MCP messages
|
||||||
sse_transport = SseServerTransport(f"{self._mount_path}/messages/")
|
sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
|
||||||
|
|
||||||
# Define MCP connection handler
|
# Route for MCP connection
|
||||||
|
@router.get(mount_path)
|
||||||
async def handle_mcp_connection(request: Request):
|
async def handle_mcp_connection(request: Request):
|
||||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
|
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
|
||||||
await self.mcp_server.run(
|
await self.mcp_server.run(
|
||||||
streams[0],
|
reader,
|
||||||
streams[1],
|
writer,
|
||||||
self.mcp_server.create_initialization_options(
|
self.mcp_server.create_initialization_options(
|
||||||
notification_options=None, experimental_capabilities={}
|
notification_options=None, experimental_capabilities={}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mount the MCP connection handler
|
# Route for MCP messages
|
||||||
self.fastapi.get(self._mount_path)(handle_mcp_connection)
|
@router.post(f"{mount_path}/messages/")
|
||||||
self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message)
|
async def handle_post_message(request: Request):
|
||||||
|
await sse_transport.handle_fastapi_post_message(request)
|
||||||
|
|
||||||
|
logger.info(f"MCP server listening at {mount_path}")
|
||||||
|
|||||||
0
fastapi_mcp/transport/__init__.py
Normal file
0
fastapi_mcp/transport/__init__.py
Normal file
72
fastapi_mcp/transport/sse.py
Normal file
72
fastapi_mcp/transport/sse.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from mcp.server.sse import SseServerTransport
|
||||||
|
from mcp.types import JSONRPCMessage
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FastApiSseTransport(SseServerTransport):
|
||||||
|
async def handle_fastapi_post_message(self, request: Request) -> None:
|
||||||
|
"""
|
||||||
|
A reimplementation of the handle_post_message method of SseServerTransport
|
||||||
|
that integrates better with FastAPI.
|
||||||
|
|
||||||
|
A few good reasons for doing this:
|
||||||
|
1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
|
||||||
|
approach. Mounting has some known issues and limitations.
|
||||||
|
2. Avoid re-constructing the scope, receive, and send from the request, as done
|
||||||
|
in the original implementation.
|
||||||
|
|
||||||
|
The combination of mounting a whole Starlette app and reconstructing the scope
|
||||||
|
and send from the request proved to be especially error-prone for us when using
|
||||||
|
tracing tools like Sentry, which had destructive effects on the request object
|
||||||
|
when using the original implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.debug("Handling POST message")
|
||||||
|
scope = request.scope
|
||||||
|
receive = request.receive
|
||||||
|
send = request._send
|
||||||
|
|
||||||
|
session_id_param = request.query_params.get("session_id")
|
||||||
|
if session_id_param is None:
|
||||||
|
logger.warning("Received request without session_id")
|
||||||
|
response = Response("session_id is required", status_code=400)
|
||||||
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
|
try:
|
||||||
|
session_id = UUID(hex=session_id_param)
|
||||||
|
logger.debug(f"Parsed session ID: {session_id}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||||
|
response = Response("Invalid session ID", status_code=400)
|
||||||
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
|
writer = self._read_stream_writers.get(session_id)
|
||||||
|
if not writer:
|
||||||
|
logger.warning(f"Could not find session for ID: {session_id}")
|
||||||
|
response = Response("Could not find session", status_code=404)
|
||||||
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
|
body = await request.body()
|
||||||
|
logger.debug(f"Received JSON: {body.decode()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = JSONRPCMessage.model_validate_json(body)
|
||||||
|
logger.debug(f"Validated client message: {message}")
|
||||||
|
except ValidationError as err:
|
||||||
|
logger.error(f"Failed to parse message: {err}")
|
||||||
|
response = Response("Could not parse message", status_code=400)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
await writer.send(err)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Sending message to writer: {message}")
|
||||||
|
response = Response("Accepted", status_code=202)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
await writer.send(message)
|
||||||
Reference in New Issue
Block a user