ensure message path is built correctly, and introduce a hacky fix for mounting on APIRouter

This commit is contained in:
Shahar Abramov
2025-04-10 14:35:34 +03:00
parent ca7b1f4644
commit a6433f1fc2
2 changed files with 22 additions and 5 deletions

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel
class LoggingConfig(BaseModel):
LOGGER_NAME: str = "fastapi_mcp"
LOG_FORMAT: str = "%(levelprefix)s %(asctime)s\t[%(name)s] %(message)s"
LOG_LEVEL: str = logging.getLevelName(logging.INFO)
LOG_LEVEL: str = logging.getLevelName(logging.DEBUG)
version: int = 1
disable_existing_loggers: bool = False

View File

@@ -159,11 +159,20 @@ class FastApiMCP:
if not router:
router = self.fastapi
# Create SSE transport for MCP messages
sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
# Build the base path correctly for the SSE transport
if isinstance(router, FastAPI):
base_path = router.root_path
elif isinstance(router, APIRouter):
base_path = self.fastapi.root_path + router.prefix
else:
raise ValueError(f"Invalid router type: {type(router)}")
messages_path = f"{base_path}{mount_path}/messages/"
sse_transport = FastApiSseTransport(messages_path)
# Route for MCP connection
@router.get(mount_path, include_in_schema=False)
@router.get(mount_path, include_in_schema=False, operation_id="mcp_connection")
async def handle_mcp_connection(request: Request):
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
await self.server.run(
@@ -173,10 +182,18 @@ class FastApiMCP:
)
# Route for MCP messages
@router.post(f"{mount_path}/messages/", include_in_schema=False)
@router.post(f"{mount_path}/messages/", include_in_schema=False, operation_id="mcp_messages")
async def handle_post_message(request: Request):
return await sse_transport.handle_fastapi_post_message(request)
# HACK: If we got a router and not a FastAPI instance, we need to re-include the router so that
# FastAPI will pick up the new routes we added. The problem with this approach is that we assume
# that the router is a sub-router of self.fastapi, which may not always be the case.
#
# TODO: Find a better way to do this.
if isinstance(router, APIRouter):
self.fastapi.include_router(router)
logger.info(f"MCP server listening at {mount_path}")
async def _execute_api_tool(