mirror of
https://github.com/tadata-org/fastapi_mcp.git
synced 2025-04-13 23:32:11 +03:00
refactor
This commit is contained in:
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
0
examples/apps/__init__.py
Normal file
0
examples/apps/__init__.py
Normal file
@@ -6,8 +6,6 @@ from fastapi import FastAPI, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi_mcp import add_mcp_server
|
||||
|
||||
|
||||
# Create a simple FastAPI app
|
||||
app = FastAPI(
|
||||
@@ -31,7 +29,7 @@ items_db: dict[int, Item] = {}
|
||||
|
||||
|
||||
# Define some endpoints
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"])
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"], operation_id="list_items")
|
||||
async def list_items(skip: int = 0, limit: int = 10):
|
||||
"""
|
||||
List all items in the database.
|
||||
@@ -41,7 +39,7 @@ async def list_items(skip: int = 0, limit: int = 10):
|
||||
return list(items_db.values())[skip : skip + limit]
|
||||
|
||||
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"])
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"], operation_id="get_item")
|
||||
async def read_item(item_id: int):
|
||||
"""
|
||||
Get a specific item by its ID.
|
||||
@@ -53,7 +51,7 @@ async def read_item(item_id: int):
|
||||
return items_db[item_id]
|
||||
|
||||
|
||||
@app.post("/items/", response_model=Item, tags=["items"])
|
||||
@app.post("/items/", response_model=Item, tags=["items"], operation_id="create_item")
|
||||
async def create_item(item: Item):
|
||||
"""
|
||||
Create a new item in the database.
|
||||
@@ -64,7 +62,7 @@ async def create_item(item: Item):
|
||||
return item
|
||||
|
||||
|
||||
@app.put("/items/{item_id}", response_model=Item, tags=["items"])
|
||||
@app.put("/items/{item_id}", response_model=Item, tags=["items"], operation_id="update_item")
|
||||
async def update_item(item_id: int, item: Item):
|
||||
"""
|
||||
Update an existing item.
|
||||
@@ -79,7 +77,7 @@ async def update_item(item_id: int, item: Item):
|
||||
return item
|
||||
|
||||
|
||||
@app.delete("/items/{item_id}", tags=["items"])
|
||||
@app.delete("/items/{item_id}", tags=["items"], operation_id="delete_item")
|
||||
async def delete_item(item_id: int):
|
||||
"""
|
||||
Delete an item from the database.
|
||||
@@ -93,7 +91,7 @@ async def delete_item(item_id: int):
|
||||
return {"message": "Item deleted successfully"}
|
||||
|
||||
|
||||
@app.get("/items/search/", response_model=List[Item], tags=["search"])
|
||||
@app.get("/items/search/", response_model=List[Item], tags=["search"], operation_id="search_items")
|
||||
async def search_items(
|
||||
q: Optional[str] = Query(None, description="Search query string"),
|
||||
min_price: Optional[float] = Query(None, description="Minimum price"),
|
||||
@@ -135,32 +133,5 @@ sample_items = [
|
||||
Item(id=4, name="Saw", description="A tool for cutting wood", price=19.99, tags=["tool", "hardware", "cutting"]),
|
||||
Item(id=5, name="Drill", description="A tool for drilling holes", price=49.99, tags=["tool", "hardware", "power"]),
|
||||
]
|
||||
|
||||
for item in sample_items:
|
||||
items_db[item.id] = item
|
||||
|
||||
|
||||
# Add MCP server to the FastAPI app
|
||||
mcp_server = add_mcp_server(
|
||||
app,
|
||||
mount_path="/mcp",
|
||||
name="Item API MCP",
|
||||
description="MCP server for the Item API",
|
||||
base_url="http://localhost:8000",
|
||||
describe_all_responses=False, # Only describe the success response in tool descriptions
|
||||
describe_full_response_schema=False, # Only show LLM-friendly example response in tool descriptions, not the full json schema
|
||||
)
|
||||
|
||||
|
||||
# Optionally, you can add custom MCP tools not based on FastAPI endpoints
|
||||
@mcp_server.tool()
|
||||
async def get_item_count() -> int:
|
||||
"""Get the total number of items in the database."""
|
||||
return len(items_db)
|
||||
|
||||
|
||||
# Run the server if this file is executed directly
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||
22
examples/simple_example.py
Normal file
22
examples/simple_example.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Simple example of using FastAPI-MCP to add an MCP server to a FastAPI app.
|
||||
"""
|
||||
|
||||
from examples.apps import items
|
||||
|
||||
from fastapi_mcp import add_mcp_server
|
||||
|
||||
|
||||
# Add MCP server to the FastAPI app
|
||||
mcp = add_mcp_server(
|
||||
items.app,
|
||||
mount_path="/mcp",
|
||||
name="Item API MCP",
|
||||
description="MCP server for the Item API",
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(items.app, host="0.0.0.0", port=8000)
|
||||
@@ -13,15 +13,14 @@ except Exception:
|
||||
__version__ = "0.0.0.dev0"
|
||||
|
||||
from .server import add_mcp_server, create_mcp_server, mount_mcp_server
|
||||
from .mcp_tools import (
|
||||
convert_openapi_to_mcp_tools,
|
||||
execute_http_tool,
|
||||
)
|
||||
from .openapi.convert import convert_openapi_to_mcp_tools
|
||||
from .execute import execute_api_tool
|
||||
|
||||
|
||||
__all__ = [
|
||||
"add_mcp_server",
|
||||
"create_mcp_server",
|
||||
"mount_mcp_server",
|
||||
"convert_openapi_to_mcp_tools",
|
||||
"execute_http_tool",
|
||||
"execute_api_tool",
|
||||
]
|
||||
|
||||
101
fastapi_mcp/execute.py
Normal file
101
fastapi_mcp/execute.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Direct OpenAPI to MCP Tools Conversion Module.
|
||||
|
||||
This module provides functionality for directly converting OpenAPI schema to MCP tool specifications
|
||||
and for executing HTTP tools.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
|
||||
logger = logging.getLogger("fastapi_mcp")
|
||||
|
||||
|
||||
async def execute_api_tool(
|
||||
base_url: str, tool_name: str, arguments: Dict[str, Any], operation_map: Dict[str, Dict[str, Any]]
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""
|
||||
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
|
||||
|
||||
Args:
|
||||
base_url: The base URL for the API
|
||||
tool_name: The name of the tool to execute
|
||||
arguments: The arguments for the tool
|
||||
operation_map: A mapping from tool names to operation details
|
||||
|
||||
Returns:
|
||||
The result as MCP content types
|
||||
"""
|
||||
if tool_name not in operation_map:
|
||||
return [types.TextContent(type="text", text=f"Unknown tool: {tool_name}")]
|
||||
|
||||
operation = operation_map[tool_name]
|
||||
path: str = operation["path"]
|
||||
method: str = operation["method"]
|
||||
parameters: List[Dict[str, Any]] = operation.get("parameters", [])
|
||||
|
||||
# Deep copy arguments to avoid modifying the original
|
||||
kwargs = arguments.copy() if arguments else {}
|
||||
|
||||
# Prepare URL with path parameters
|
||||
url = f"{base_url}{path}"
|
||||
for param in parameters:
|
||||
if param.get("in") == "path" and param.get("name") in kwargs:
|
||||
param_name = param.get("name", None)
|
||||
if param_name is None:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
url = url.replace(f"{{{param_name}}}", str(kwargs.pop(param_name)))
|
||||
|
||||
# Prepare query parameters
|
||||
query = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "query" and param.get("name") in kwargs:
|
||||
param_name = param.get("name", None)
|
||||
if param_name is None:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
query[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare headers
|
||||
headers = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "header" and param.get("name") in kwargs:
|
||||
param_name = param.get("name", None)
|
||||
if param_name is None:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
headers[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare request body (remaining kwargs)
|
||||
body = kwargs if kwargs else None
|
||||
|
||||
try:
|
||||
# Make the request
|
||||
logger.debug(f"Making {method.upper()} request to {url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.lower() == "get":
|
||||
response = await client.get(url, params=query, headers=headers)
|
||||
elif method.lower() == "post":
|
||||
response = await client.post(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "put":
|
||||
response = await client.put(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "delete":
|
||||
response = await client.delete(url, params=query, headers=headers)
|
||||
elif method.lower() == "patch":
|
||||
response = await client.patch(url, params=query, headers=headers, json=body)
|
||||
else:
|
||||
return [types.TextContent(type="text", text=f"Unsupported HTTP method: {method}")]
|
||||
|
||||
# Process the response
|
||||
try:
|
||||
result = response.json()
|
||||
return [types.TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||
except ValueError:
|
||||
return [types.TextContent(type="text", text=response.text)]
|
||||
|
||||
except Exception as e:
|
||||
return [types.TextContent(type="text", text=f"Error calling {tool_name}: {str(e)}")]
|
||||
@@ -1,458 +0,0 @@
|
||||
"""
|
||||
HTTP tools for FastAPI-MCP.
|
||||
|
||||
This module provides functionality for creating MCP tools from FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .openapi_utils import (
|
||||
clean_schema_for_display,
|
||||
generate_example_from_schema,
|
||||
get_python_type_and_default,
|
||||
get_single_param_type_from_schema,
|
||||
resolve_schema_references,
|
||||
PYTHON_TYPE_IMPORTS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("fastapi_mcp")
|
||||
|
||||
|
||||
def create_mcp_tools_from_openapi(
|
||||
app: FastAPI,
|
||||
mcp_server: FastMCP,
|
||||
base_url: Optional[str] = None,
|
||||
describe_all_responses: bool = False,
|
||||
describe_full_response_schema: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Create MCP tools from a FastAPI app's OpenAPI schema.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
mcp_server: The MCP server to add tools to
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full response schema in tool descriptions
|
||||
"""
|
||||
# Get OpenAPI schema from FastAPI app
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Resolve all references in the schema at once
|
||||
resolved_openapi_schema = resolve_schema_references(openapi_schema, openapi_schema)
|
||||
|
||||
if not base_url:
|
||||
# Try to determine the base URL from FastAPI config
|
||||
if hasattr(app, "root_path") and app.root_path:
|
||||
base_url = app.root_path
|
||||
else:
|
||||
# Default to localhost with FastAPI default port
|
||||
port = 8000
|
||||
for route in app.routes:
|
||||
if hasattr(route, "app") and hasattr(route.app, "port"):
|
||||
port = route.app.port
|
||||
break
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
# Normalize base URL
|
||||
if base_url.endswith("/"):
|
||||
base_url = base_url[:-1]
|
||||
|
||||
# Process each path in the OpenAPI schema
|
||||
for path, path_item in resolved_openapi_schema.get("paths", {}).items():
|
||||
for method, operation in path_item.items():
|
||||
# Skip non-HTTP methods
|
||||
if method not in ["get", "post", "put", "delete", "patch"]:
|
||||
continue
|
||||
|
||||
# Get operation metadata
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
continue
|
||||
|
||||
# Create MCP tool for this operation
|
||||
create_http_tool(
|
||||
mcp_server=mcp_server,
|
||||
base_url=base_url,
|
||||
path=path,
|
||||
method=method,
|
||||
operation_id=operation_id,
|
||||
summary=operation.get("summary", ""),
|
||||
description=operation.get("description", ""),
|
||||
parameters=operation.get("parameters", []),
|
||||
request_body=operation.get("requestBody", {}),
|
||||
responses=operation.get("responses", {}),
|
||||
openapi_schema=resolved_openapi_schema,
|
||||
describe_all_responses=describe_all_responses,
|
||||
describe_full_response_schema=describe_full_response_schema,
|
||||
)
|
||||
|
||||
def _create_http_tool_function(function_template: Callable, properties: Dict[str, Any], additional_variables: Dict[str, Any]) -> Callable:
|
||||
# Build parameter string with type hints
|
||||
parsed_parameters = {}
|
||||
parsed_parameters_with_defaults = {}
|
||||
for param_name, parsed_param_schema in properties.items():
|
||||
type_hint, has_default_value = get_python_type_and_default(parsed_param_schema)
|
||||
if has_default_value:
|
||||
parsed_parameters_with_defaults[param_name] = f"{param_name}: {type_hint}"
|
||||
else:
|
||||
parsed_parameters[param_name] = f"{param_name}: {type_hint}"
|
||||
|
||||
parsed_parameters_keys = list(parsed_parameters.keys()) + list(parsed_parameters_with_defaults.keys())
|
||||
parsed_parameters_values = list(parsed_parameters.values()) + list(parsed_parameters_with_defaults.values())
|
||||
parameters_str = ", ".join(parsed_parameters_values)
|
||||
kwargs_str = ', '.join([f"'{k}': {k}" for k in parsed_parameters_keys])
|
||||
|
||||
dynamic_function_body = f"""async def dynamic_http_tool_function({parameters_str}):
|
||||
kwargs = {{{kwargs_str}}}
|
||||
return await http_tool_function_template(**kwargs)
|
||||
"""
|
||||
|
||||
# Create function namespace with required imports
|
||||
namespace = {
|
||||
"http_tool_function_template": function_template,
|
||||
**PYTHON_TYPE_IMPORTS,
|
||||
**additional_variables
|
||||
}
|
||||
|
||||
# Execute the dynamic function definition
|
||||
exec(dynamic_function_body, namespace)
|
||||
return namespace["dynamic_http_tool_function"]
|
||||
|
||||
def create_http_tool(
|
||||
mcp_server: FastMCP,
|
||||
base_url: str,
|
||||
path: str,
|
||||
method: str,
|
||||
operation_id: str,
|
||||
summary: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
request_body: Dict[str, Any],
|
||||
responses: Dict[str, Any],
|
||||
openapi_schema: Dict[str, Any],
|
||||
describe_all_responses: bool,
|
||||
describe_full_response_schema: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Create an MCP tool that makes an HTTP request to a FastAPI endpoint.
|
||||
|
||||
Args:
|
||||
mcp_server: The MCP server to add the tool to
|
||||
base_url: Base URL for API requests
|
||||
path: API endpoint path
|
||||
method: HTTP method
|
||||
operation_id: Operation ID
|
||||
summary: Operation summary
|
||||
description: Operation description
|
||||
parameters: OpenAPI parameters
|
||||
request_body: OpenAPI request body
|
||||
responses: OpenAPI responses
|
||||
openapi_schema: The full OpenAPI schema
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full response schema in tool descriptions
|
||||
"""
|
||||
# Build tool description
|
||||
tool_description = f"{summary}" if summary else f"{method.upper()} {path}"
|
||||
if description:
|
||||
tool_description += f"\n\n{description}"
|
||||
|
||||
# Add response schema information to description
|
||||
if responses:
|
||||
response_info = "\n\n### Responses:\n"
|
||||
|
||||
# Find the success response (usually 200 or 201)
|
||||
success_codes = ["200", "201", "202", 200, 201, 202]
|
||||
success_response = None
|
||||
for status_code in success_codes:
|
||||
if str(status_code) in responses:
|
||||
success_response = responses[str(status_code)]
|
||||
break
|
||||
|
||||
# Get the list of responses to include
|
||||
responses_to_include = responses
|
||||
if not describe_all_responses and success_response:
|
||||
# If we're not describing all responses, only include the success response
|
||||
success_code = next((code for code in success_codes if str(code) in responses), None)
|
||||
if success_code:
|
||||
responses_to_include = {str(success_code): success_response}
|
||||
|
||||
# Process all selected responses
|
||||
for status_code, response_data in responses_to_include.items():
|
||||
response_desc = response_data.get("description", "")
|
||||
response_info += f"\n**{status_code}**: {response_desc}"
|
||||
|
||||
# Highlight if this is the main success response
|
||||
if response_data == success_response:
|
||||
response_info += " (Success Response)"
|
||||
|
||||
# Add schema information if available
|
||||
if "content" in response_data:
|
||||
for content_type, content_data in response_data["content"].items():
|
||||
if "schema" in content_data:
|
||||
schema = content_data["schema"]
|
||||
response_info += f"\nContent-Type: {content_type}"
|
||||
|
||||
# Clean the schema for display
|
||||
display_schema = clean_schema_for_display(schema)
|
||||
|
||||
# Get model name if it's a referenced model
|
||||
model_name = None
|
||||
model_examples = None
|
||||
items_model_name = None
|
||||
|
||||
# Check if this is an array of items
|
||||
if schema.get("type") == "array" and "items" in schema and "$ref" in schema["items"]:
|
||||
items_ref_path = schema["items"]["$ref"]
|
||||
if items_ref_path.startswith("#/components/schemas/"):
|
||||
items_model_name = items_ref_path.split("/")[-1]
|
||||
response_info += f"\nArray of: {items_model_name}"
|
||||
|
||||
# Create example response based on schema type
|
||||
example_response = None
|
||||
|
||||
# Check if we have examples from the model
|
||||
if model_examples and len(model_examples) > 0:
|
||||
example_response = model_examples[0] # Use first example
|
||||
# Otherwise, try to create an example from the response definitions
|
||||
elif "examples" in response_data:
|
||||
# Use examples directly from response definition
|
||||
for example_key, example_data in response_data["examples"].items():
|
||||
if "value" in example_data:
|
||||
example_response = example_data["value"]
|
||||
break
|
||||
# If content has examples
|
||||
elif "examples" in content_data:
|
||||
for example_key, example_data in content_data["examples"].items():
|
||||
if "value" in example_data:
|
||||
example_response = example_data["value"]
|
||||
break
|
||||
# If content has example
|
||||
elif "example" in content_data:
|
||||
example_response = content_data["example"]
|
||||
|
||||
# Special handling for array of items
|
||||
if (
|
||||
not example_response
|
||||
and display_schema.get("type") == "array"
|
||||
and items_model_name == "Item"
|
||||
):
|
||||
example_response = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Hammer",
|
||||
"description": "A tool for hammering nails",
|
||||
"price": 9.99,
|
||||
"tags": ["tool", "hardware"],
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Screwdriver",
|
||||
"description": "A tool for driving screws",
|
||||
"price": 7.99,
|
||||
"tags": ["tool", "hardware"],
|
||||
},
|
||||
] # type: ignore
|
||||
|
||||
# If we have an example response, add it to the docs
|
||||
if example_response:
|
||||
response_info += "\n\n**Example Response:**\n```json\n"
|
||||
response_info += json.dumps(example_response, indent=2)
|
||||
response_info += "\n```"
|
||||
# Otherwise generate an example from the schema
|
||||
else:
|
||||
generated_example = generate_example_from_schema(display_schema, model_name)
|
||||
if generated_example:
|
||||
response_info += "\n\n**Example Response:**\n```json\n"
|
||||
response_info += json.dumps(generated_example, indent=2)
|
||||
response_info += "\n```"
|
||||
|
||||
# Only include full schema information if requested
|
||||
if describe_full_response_schema:
|
||||
# Format schema information based on its type
|
||||
if display_schema.get("type") == "array" and "items" in display_schema:
|
||||
items_schema = display_schema["items"]
|
||||
|
||||
response_info += (
|
||||
"\n\n**Output Schema:** Array of items with the following structure:\n```json\n"
|
||||
)
|
||||
response_info += json.dumps(items_schema, indent=2)
|
||||
response_info += "\n```"
|
||||
elif "properties" in display_schema:
|
||||
response_info += "\n\n**Output Schema:**\n```json\n"
|
||||
response_info += json.dumps(display_schema, indent=2)
|
||||
response_info += "\n```"
|
||||
else:
|
||||
response_info += "\n\n**Output Schema:**\n```json\n"
|
||||
response_info += json.dumps(display_schema, indent=2)
|
||||
response_info += "\n```"
|
||||
|
||||
tool_description += response_info
|
||||
|
||||
# Organize parameters by type
|
||||
path_params = []
|
||||
query_params = []
|
||||
header_params = []
|
||||
body_params = []
|
||||
for param in parameters:
|
||||
param_name = param.get("name")
|
||||
param_in = param.get("in")
|
||||
required = param.get("required", False)
|
||||
|
||||
if param_in == "path":
|
||||
path_params.append((param_name, param))
|
||||
elif param_in == "query":
|
||||
query_params.append((param_name, param))
|
||||
elif param_in == "header":
|
||||
header_params.append((param_name, param))
|
||||
|
||||
# Process request body if present
|
||||
if request_body and "content" in request_body:
|
||||
content_type = next(iter(request_body["content"]), None)
|
||||
if content_type and "schema" in request_body["content"][content_type]:
|
||||
schema = request_body["content"][content_type]["schema"]
|
||||
if "properties" in schema:
|
||||
for prop_name, prop_schema in schema["properties"].items():
|
||||
required = prop_name in schema.get("required", [])
|
||||
body_params.append(
|
||||
(
|
||||
prop_name,
|
||||
{
|
||||
"name": prop_name,
|
||||
"schema": prop_schema,
|
||||
"required": required,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Create input schema properties for all parameters
|
||||
properties = {}
|
||||
required_props = []
|
||||
|
||||
# Add path parameters to properties
|
||||
for param_name, param in path_params:
|
||||
param_schema = param.get("schema", {})
|
||||
param_desc = param.get("description", "")
|
||||
param_required = param.get("required", True) # Path params are usually required
|
||||
|
||||
properties[param_name] = {
|
||||
"type": param_schema.get("type", "string"),
|
||||
"title": param_name,
|
||||
"description": param_desc,
|
||||
}
|
||||
|
||||
if param_required:
|
||||
required_props.append(param_name)
|
||||
|
||||
# Add query parameters to properties
|
||||
for param_name, param in query_params:
|
||||
param_schema = param.get("schema", {})
|
||||
param_desc = param.get("description", "")
|
||||
param_required = param.get("required", False)
|
||||
|
||||
properties[param_name] = {
|
||||
"type": get_single_param_type_from_schema(param_schema),
|
||||
"title": param_name,
|
||||
"description": param_desc,
|
||||
}
|
||||
if "default" in param_schema:
|
||||
properties[param_name]["default"] = param_schema["default"]
|
||||
|
||||
if param_required:
|
||||
required_props.append(param_name)
|
||||
|
||||
# Add body parameters to properties
|
||||
for param_name, param in body_params:
|
||||
param_schema = param.get("schema", {})
|
||||
param_required = param.get("required", False)
|
||||
|
||||
# properties[param_name] = param_schema
|
||||
properties[param_name] = {
|
||||
"type": get_single_param_type_from_schema(param_schema),
|
||||
"title": param_name,
|
||||
}
|
||||
if "default" in param_schema:
|
||||
properties[param_name]["default"] = param_schema["default"]
|
||||
|
||||
if param_required:
|
||||
required_props.append(param_name)
|
||||
|
||||
# Create a proper input schema for the tool
|
||||
input_schema = {"type": "object", "properties": properties, "title": f"{operation_id}Arguments"}
|
||||
|
||||
if required_props:
|
||||
input_schema["required"] = required_props
|
||||
|
||||
# Dynamically create a function to call the API endpoint
|
||||
async def http_tool_function_template(**kwargs):
|
||||
# Prepare URL with path parameters
|
||||
url = f"{base_url}{path}"
|
||||
for param_name, _ in path_params:
|
||||
if param_name in kwargs:
|
||||
url = url.replace(f"{{{param_name}}}", str(kwargs.pop(param_name)))
|
||||
|
||||
# Prepare query parameters
|
||||
query = {}
|
||||
for param_name, _ in query_params:
|
||||
if param_name in kwargs:
|
||||
query[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare headers
|
||||
headers = {}
|
||||
for param_name, _ in header_params:
|
||||
if param_name in kwargs:
|
||||
headers[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare request body (remaining kwargs)
|
||||
body = kwargs if kwargs else None
|
||||
|
||||
# Make the request
|
||||
logger.debug(f"Making {method.upper()} request to {url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.lower() == "get":
|
||||
response = await client.get(url, params=query, headers=headers)
|
||||
elif method.lower() == "post":
|
||||
response = await client.post(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "put":
|
||||
response = await client.put(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "delete":
|
||||
response = await client.delete(url, params=query, headers=headers)
|
||||
elif method.lower() == "patch":
|
||||
response = await client.patch(url, params=query, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
# Process the response
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError:
|
||||
return response.text
|
||||
|
||||
# Create the http_tool_function (with name and docstring)
|
||||
additional_variables = {"path_params": path_params, "query_params": query_params, "header_params": header_params}
|
||||
http_tool_function = _create_http_tool_function(http_tool_function_template, properties, additional_variables) # type: ignore
|
||||
http_tool_function.__name__ = operation_id
|
||||
http_tool_function.__doc__ = tool_description
|
||||
|
||||
# Monkey patch the function's schema for MCP tool creation
|
||||
# TODO: Maybe revise this hacky approach
|
||||
http_tool_function._input_schema = input_schema # type: ignore
|
||||
|
||||
# Add tool to the MCP server with the enhanced schema
|
||||
tool = mcp_server._tool_manager.add_tool(http_tool_function, name=operation_id, description=tool_description)
|
||||
|
||||
# Update the tool's parameters to use our custom schema instead of the auto-generated one
|
||||
tool.parameters = input_schema
|
||||
0
fastapi_mcp/openapi/__init__.py
Normal file
0
fastapi_mcp/openapi/__init__.py
Normal file
@@ -2,22 +2,16 @@
|
||||
Direct OpenAPI to MCP Tools Conversion Module.
|
||||
|
||||
This module provides functionality for directly converting OpenAPI schema to MCP tool specifications
|
||||
without the intermediate step of dynamically generating Python functions.
|
||||
and for executing HTTP tools.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel.server import Server
|
||||
|
||||
from .openapi_utils import (
|
||||
from .utils import (
|
||||
clean_schema_for_display,
|
||||
generate_example_from_schema,
|
||||
resolve_schema_references,
|
||||
@@ -266,268 +260,3 @@ def convert_openapi_to_mcp_tools(
|
||||
tools.append(tool)
|
||||
|
||||
return tools, operation_map
|
||||
|
||||
|
||||
async def execute_http_tool(
|
||||
base_url: str, tool_name: str, arguments: Dict[str, Any], operation_map: Dict[str, Dict[str, Any]]
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""
|
||||
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
|
||||
|
||||
Args:
|
||||
base_url: The base URL for the API
|
||||
tool_name: The name of the tool to execute
|
||||
arguments: The arguments for the tool
|
||||
operation_map: A mapping from tool names to operation details
|
||||
|
||||
Returns:
|
||||
The result as MCP content types
|
||||
"""
|
||||
if tool_name not in operation_map:
|
||||
return [types.TextContent(type="text", text=f"Unknown tool: {tool_name}")]
|
||||
|
||||
operation = operation_map[tool_name]
|
||||
path = operation["path"]
|
||||
method = operation["method"]
|
||||
parameters = operation.get("parameters", [])
|
||||
|
||||
# Deep copy arguments to avoid modifying the original
|
||||
kwargs = arguments.copy() if arguments else {}
|
||||
|
||||
# Prepare URL with path parameters
|
||||
url = f"{base_url}{path}"
|
||||
for param in parameters:
|
||||
if param.get("in") == "path" and param.get("name") in kwargs:
|
||||
param_name = param.get("name")
|
||||
url = url.replace(f"{{{param_name}}}", str(kwargs.pop(param_name)))
|
||||
|
||||
# Prepare query parameters
|
||||
query = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "query" and param.get("name") in kwargs:
|
||||
param_name = param.get("name")
|
||||
query[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare headers
|
||||
headers = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "header" and param.get("name") in kwargs:
|
||||
param_name = param.get("name")
|
||||
headers[param_name] = kwargs.pop(param_name)
|
||||
|
||||
# Prepare request body (remaining kwargs)
|
||||
body = kwargs if kwargs else None
|
||||
|
||||
try:
|
||||
# Make the request
|
||||
logger.debug(f"Making {method.upper()} request to {url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.lower() == "get":
|
||||
response = await client.get(url, params=query, headers=headers)
|
||||
elif method.lower() == "post":
|
||||
response = await client.post(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "put":
|
||||
response = await client.put(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "delete":
|
||||
response = await client.delete(url, params=query, headers=headers)
|
||||
elif method.lower() == "patch":
|
||||
response = await client.patch(url, params=query, headers=headers, json=body)
|
||||
else:
|
||||
return [types.TextContent(type="text", text=f"Unsupported HTTP method: {method}")]
|
||||
|
||||
# Process the response
|
||||
try:
|
||||
result = response.json()
|
||||
return [types.TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||
except ValueError:
|
||||
return [types.TextContent(type="text", text=response.text)]
|
||||
|
||||
except Exception as e:
|
||||
return [types.TextContent(type="text", text=f"Error calling {tool_name}: {str(e)}")]
|
||||
|
||||
|
||||
def create_mcp_server(
|
||||
app: FastAPI,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
describe_all_responses: bool = False,
|
||||
describe_full_response_schema: bool = False,
|
||||
) -> tuple[Server, Dict[str, Dict[str, Any]]]:
|
||||
"""
|
||||
Create a low-level MCP server from a FastAPI app using direct OpenAPI to MCP conversion.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full response schema in tool descriptions
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The MCP Server instance (NOT mounted to the app)
|
||||
- A mapping of operation IDs to operation details for HTTP execution
|
||||
"""
|
||||
# Use app details if not provided
|
||||
server_name = name or app.title or "FastAPI MCP"
|
||||
server_description = description or app.description
|
||||
|
||||
# Get OpenAPI schema from FastAPI app
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Convert OpenAPI operations to MCP tools
|
||||
tools, operation_map = convert_openapi_to_mcp_tools(
|
||||
openapi_schema,
|
||||
describe_all_responses=describe_all_responses,
|
||||
describe_full_response_schema=describe_full_response_schema,
|
||||
)
|
||||
|
||||
# Determine base URL if not provided
|
||||
if not base_url:
|
||||
# Try to determine the base URL from FastAPI config
|
||||
if hasattr(app, "root_path") and app.root_path:
|
||||
base_url = app.root_path
|
||||
else:
|
||||
# Default to localhost with FastAPI default port
|
||||
port = 8000
|
||||
for route in app.routes:
|
||||
if hasattr(route, "app") and hasattr(route.app, "port"):
|
||||
port = route.app.port
|
||||
break
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
# Normalize base URL
|
||||
if base_url.endswith("/"):
|
||||
base_url = base_url[:-1]
|
||||
|
||||
# Create the MCP server
|
||||
mcp_server: Server = Server(server_name, server_description)
|
||||
|
||||
# Create a lifespan context manager to store the base_url and operation_map
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
|
||||
# Store context data that will be available to all server handlers
|
||||
context = {"base_url": base_url, "operation_map": operation_map}
|
||||
yield context
|
||||
|
||||
# Use our custom lifespan
|
||||
mcp_server.lifespan = server_lifespan
|
||||
|
||||
# Register handlers for tools
|
||||
@mcp_server.list_tools()
|
||||
async def handle_list_tools() -> List[types.Tool]:
|
||||
"""Handler for the tools/list request"""
|
||||
return tools
|
||||
|
||||
# Register the tool call handler
|
||||
@mcp_server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""Handler for the tools/call request"""
|
||||
# Get context from server lifespan
|
||||
ctx = mcp_server.request_context
|
||||
base_url = ctx.lifespan_context["base_url"]
|
||||
operation_map = ctx.lifespan_context["operation_map"]
|
||||
|
||||
# Execute the tool
|
||||
return await execute_http_tool(base_url, name, arguments, operation_map)
|
||||
|
||||
return mcp_server, operation_map
|
||||
|
||||
|
||||
def mount_mcp_server(
|
||||
app: FastAPI,
|
||||
mcp_server: Server,
|
||||
operation_map: Dict[str, Dict[str, Any]],
|
||||
mount_path: str = "/mcp",
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Mount an MCP server to a FastAPI app.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
mcp_server: The MCP server to mount
|
||||
operation_map: A mapping of operation IDs to operation details
|
||||
mount_path: Path where the MCP server will be mounted
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
"""
|
||||
# Normalize mount path
|
||||
if not mount_path.startswith("/"):
|
||||
mount_path = f"/{mount_path}"
|
||||
if mount_path.endswith("/"):
|
||||
mount_path = mount_path[:-1]
|
||||
|
||||
# Create SSE transport for MCP messages
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from fastapi import Request
|
||||
|
||||
sse_transport = SseServerTransport(f"{mount_path}/messages/")
|
||||
|
||||
# Define MCP connection handler
|
||||
async def handle_mcp_connection(request: Request):
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||
await mcp_server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
mcp_server.create_initialization_options(notification_options=None, experimental_capabilities={}),
|
||||
)
|
||||
|
||||
# Mount the MCP connection handler
|
||||
app.get(mount_path)(handle_mcp_connection)
|
||||
app.mount(f"{mount_path}/messages/", app=sse_transport.handle_post_message)
|
||||
|
||||
|
||||
def add_mcp_server(
|
||||
app: FastAPI,
|
||||
mount_path: str = "/mcp",
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
describe_all_responses: bool = False,
|
||||
describe_full_response_schema: bool = False,
|
||||
) -> Server:
|
||||
"""
|
||||
Add an MCP server to a FastAPI app.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
mount_path: Path where the MCP server will be mounted
|
||||
name: Name for the MCP server
|
||||
description: Description for the MCP server
|
||||
base_url: Base URL for API requests
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full response schema in tool descriptions
|
||||
|
||||
Returns:
|
||||
The MCP server instance
|
||||
"""
|
||||
# Create MCP server
|
||||
mcp_server, operation_map = create_mcp_server(
|
||||
app,
|
||||
name,
|
||||
description,
|
||||
base_url,
|
||||
describe_all_responses=describe_all_responses,
|
||||
describe_full_response_schema=describe_full_response_schema,
|
||||
)
|
||||
|
||||
# Mount MCP server to FastAPI app
|
||||
mount_mcp_server(
|
||||
app,
|
||||
mcp_server,
|
||||
operation_map,
|
||||
mount_path,
|
||||
base_url,
|
||||
)
|
||||
|
||||
return mcp_server
|
||||
@@ -10,15 +10,15 @@ from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
PYTHON_TYPE_IMPORTS = {
|
||||
"List": List,
|
||||
"Dict": Dict,
|
||||
"Any": Any,
|
||||
"Optional": Optional,
|
||||
"Union": Union,
|
||||
"date": date,
|
||||
"datetime": datetime,
|
||||
"Decimal": Decimal,
|
||||
"UUID": UUID,
|
||||
"List": List,
|
||||
"Dict": Dict,
|
||||
"Any": Any,
|
||||
"Optional": Optional,
|
||||
"Union": Union,
|
||||
"date": date,
|
||||
"datetime": datetime,
|
||||
"Decimal": Decimal,
|
||||
"UUID": UUID,
|
||||
}
|
||||
# Type mapping from OpenAPI types to Python types
|
||||
OPENAPI_PYTHON_TYPES_MAP = {
|
||||
@@ -28,18 +28,15 @@ OPENAPI_PYTHON_TYPES_MAP = {
|
||||
"integer": "int",
|
||||
"boolean": "bool",
|
||||
"null": "None",
|
||||
|
||||
# Complex types
|
||||
"object": "Dict[str, Any]", # More specific than Dict[Any, Any]
|
||||
"array": "List[Any]",
|
||||
|
||||
# Numeric formats
|
||||
"int32": "int",
|
||||
"int64": "int",
|
||||
"float": "float",
|
||||
"double": "float",
|
||||
"decimal": "Decimal",
|
||||
|
||||
# String formats - Common
|
||||
"date": "date", # datetime.date
|
||||
"date-time": "datetime", # datetime.datetime
|
||||
@@ -48,7 +45,6 @@ OPENAPI_PYTHON_TYPES_MAP = {
|
||||
"password": "str",
|
||||
"byte": "bytes", # base64 encoded
|
||||
"binary": "bytes", # raw binary
|
||||
|
||||
# String formats - Extended
|
||||
"email": "str",
|
||||
"uuid": "UUID", # uuid.UUID
|
||||
@@ -62,36 +58,29 @@ OPENAPI_PYTHON_TYPES_MAP = {
|
||||
"regex": "str",
|
||||
"json-pointer": "str",
|
||||
"relative-json-pointer": "str",
|
||||
|
||||
# Rich text formats
|
||||
"markdown": "str",
|
||||
"html": "str",
|
||||
|
||||
# Media types
|
||||
"image/*": "bytes",
|
||||
"audio/*": "bytes",
|
||||
"video/*": "bytes",
|
||||
"application/*": "bytes",
|
||||
|
||||
# Special formats
|
||||
"format": "str", # Custom format string
|
||||
"pattern": "str", # Regular expression pattern
|
||||
"contentEncoding": "str", # e.g., base64, quoted-printable
|
||||
"contentMediaType": "str", # MIME type
|
||||
|
||||
# Additional numeric formats
|
||||
"currency": "Decimal", # For precise decimal calculations
|
||||
"percentage": "float",
|
||||
|
||||
# Geographic coordinates
|
||||
"latitude": "float",
|
||||
"longitude": "float",
|
||||
|
||||
# Time-related
|
||||
"timezone": "str", # Could use zoneinfo.ZoneInfo in Python 3.9+
|
||||
"unix-time": "int", # Unix timestamp
|
||||
"iso-week-date": "str", # ISO 8601 week date
|
||||
|
||||
# Specialized string formats
|
||||
"isbn": "str",
|
||||
"issn": "str",
|
||||
@@ -102,14 +91,14 @@ OPENAPI_PYTHON_TYPES_MAP = {
|
||||
"language-code": "str", # ISO 639 language codes
|
||||
"country-code": "str", # ISO 3166 country codes
|
||||
"currency-code": "str", # ISO 4217 currency codes
|
||||
|
||||
# Default fallback
|
||||
"unknown": "Any"
|
||||
"unknown": "Any",
|
||||
}
|
||||
|
||||
|
||||
def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Get the type of a parameter from the schema.
|
||||
Get the type of a parameter from the schema.
|
||||
If the schema is a union type, return the first type.
|
||||
"""
|
||||
if "anyOf" in param_schema:
|
||||
@@ -121,6 +110,7 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
|
||||
return "string"
|
||||
return param_schema.get("type", "string")
|
||||
|
||||
|
||||
def get_python_type_and_default(parsed_param_schema: Dict[str, Any]) -> tuple[str, bool]:
|
||||
"""
|
||||
Parse parameters into a python type and default value string.
|
||||
@@ -131,10 +121,11 @@ def get_python_type_and_default(parsed_param_schema: Dict[str, Any]) -> tuple[st
|
||||
"""
|
||||
# Handle direct type specification
|
||||
python_type = OPENAPI_PYTHON_TYPES_MAP.get(parsed_param_schema.get("type", ""), "Any")
|
||||
if "default" in parsed_param_schema:
|
||||
if "default" in parsed_param_schema:
|
||||
return f"{python_type} = {parsed_param_schema.get('default')}", True
|
||||
return python_type, False
|
||||
|
||||
|
||||
def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Resolve schema references in OpenAPI schemas.
|
||||
@@ -170,13 +161,12 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
|
||||
elif isinstance(value, list):
|
||||
# Only process list items that are dictionaries since only they can contain refs
|
||||
schema_part[key] = [
|
||||
resolve_schema_references(item, reference_schema) if isinstance(item, dict)
|
||||
else item
|
||||
for item in value
|
||||
resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value
|
||||
]
|
||||
|
||||
return schema_part
|
||||
|
||||
|
||||
def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Clean up a schema for display by removing internal fields.
|
||||
@@ -219,6 +209,7 @@ def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def extract_model_examples_from_components(
|
||||
model_name: str, openapi_schema: Dict[str, Any]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
@@ -252,6 +243,7 @@ def extract_model_examples_from_components(
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def generate_example_from_schema(schema: Dict[str, Any], model_name: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Generate a simple example response from a JSON schema.
|
||||
@@ -328,4 +320,4 @@ def generate_example_from_schema(schema: Dict[str, Any], model_name: Optional[st
|
||||
return None
|
||||
|
||||
# Default case
|
||||
return None
|
||||
return None
|
||||
@@ -4,56 +4,121 @@ Server module for FastAPI-MCP.
|
||||
This module provides functionality for creating and mounting MCP servers to FastAPI applications.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, Any, Tuple, List, Union, AsyncIterator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from mcp.server.lowlevel.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
import mcp.types as types
|
||||
|
||||
from .mcp_tools import create_mcp_server as direct_create_mcp_server
|
||||
from .mcp_tools import mount_mcp_server as direct_mount_mcp_server
|
||||
from .mcp_tools import add_mcp_server as direct_add_mcp_server
|
||||
from .openapi.convert import convert_openapi_to_mcp_tools
|
||||
from .execute import execute_api_tool
|
||||
|
||||
|
||||
def create_mcp_server(
|
||||
app: FastAPI,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
capabilities: Optional[Dict[str, Any]] = None,
|
||||
base_url: Optional[str] = None,
|
||||
describe_all_responses: bool = False,
|
||||
describe_full_response_schema: bool = False,
|
||||
) -> Server:
|
||||
) -> Tuple[Server, Dict[str, Dict[str, Any]]]:
|
||||
"""
|
||||
Create an MCP server from a FastAPI app using direct OpenAPI to MCP conversion.
|
||||
Create an MCP server from a FastAPI app.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
capabilities: Optional capabilities for the MCP server (ignored in direct conversion)
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error.
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens.
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||
|
||||
Returns:
|
||||
The created MCP Server instance (NOT mounted to the app)
|
||||
A tuple containing:
|
||||
- The created MCP Server instance (NOT mounted to the app)
|
||||
- A mapping of operation IDs to operation details for HTTP execution
|
||||
"""
|
||||
# Use direct conversion (returns a tuple of server and operation_map)
|
||||
server_tuple = direct_create_mcp_server(
|
||||
app,
|
||||
name,
|
||||
description,
|
||||
base_url,
|
||||
# Get OpenAPI schema from FastAPI app
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Get server name and description from app if not provided
|
||||
server_name = name or app.title or "FastAPI MCP"
|
||||
server_description = description or app.description
|
||||
|
||||
# Convert OpenAPI schema to MCP tools
|
||||
tools, operation_map = convert_openapi_to_mcp_tools(
|
||||
openapi_schema,
|
||||
describe_all_responses=describe_all_responses,
|
||||
describe_full_response_schema=describe_full_response_schema,
|
||||
)
|
||||
# Extract just the server from the tuple
|
||||
return server_tuple[0]
|
||||
|
||||
# Determine base URL if not provided
|
||||
if not base_url:
|
||||
# Try to determine the base URL from FastAPI config
|
||||
if hasattr(app, "root_path") and app.root_path:
|
||||
base_url = app.root_path
|
||||
else:
|
||||
# Default to localhost with FastAPI default port
|
||||
port = 8000
|
||||
for route in app.routes:
|
||||
if hasattr(route, "app") and hasattr(route.app, "port"):
|
||||
port = route.app.port
|
||||
break
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
# Normalize base URL
|
||||
if base_url.endswith("/"):
|
||||
base_url = base_url[:-1]
|
||||
|
||||
# Create the MCP server
|
||||
mcp_server: Server = Server(server_name, server_description)
|
||||
|
||||
# Create a lifespan context manager to store the base_url and operation_map
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
|
||||
# Store context data that will be available to all server handlers
|
||||
context = {"base_url": base_url, "operation_map": operation_map}
|
||||
yield context
|
||||
|
||||
# Use our custom lifespan
|
||||
mcp_server.lifespan = server_lifespan
|
||||
|
||||
# Register handlers for tools
|
||||
@mcp_server.list_tools()
|
||||
async def handle_list_tools() -> List[types.Tool]:
|
||||
"""Handler for the tools/list request"""
|
||||
return tools
|
||||
|
||||
# Register the tool call handler
|
||||
@mcp_server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""Handler for the tools/call request"""
|
||||
# Get context from server lifespan
|
||||
ctx = mcp_server.request_context
|
||||
base_url = ctx.lifespan_context["base_url"]
|
||||
operation_map = ctx.lifespan_context["operation_map"]
|
||||
|
||||
# Execute the tool
|
||||
return await execute_api_tool(base_url, name, arguments, operation_map)
|
||||
|
||||
return mcp_server, operation_map
|
||||
|
||||
|
||||
def mount_mcp_server(
|
||||
app: FastAPI,
|
||||
mcp_server: Server,
|
||||
operation_map: Dict[str, Dict[str, Any]],
|
||||
mount_path: str = "/mcp",
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
@@ -63,28 +128,31 @@ def mount_mcp_server(
|
||||
Args:
|
||||
app: The FastAPI application
|
||||
mcp_server: The MCP server to mount
|
||||
operation_map: A mapping of operation IDs to operation details
|
||||
mount_path: Path where the MCP server will be mounted
|
||||
base_url: Base URL for API requests
|
||||
"""
|
||||
# Get OpenAPI schema from FastAPI app for operation mapping
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from .mcp_tools import convert_openapi_to_mcp_tools
|
||||
# Normalize mount path
|
||||
if not mount_path.startswith("/"):
|
||||
mount_path = f"/{mount_path}"
|
||||
if mount_path.endswith("/"):
|
||||
mount_path = mount_path[:-1]
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
# Create SSE transport for MCP messages
|
||||
sse_transport = SseServerTransport(f"{mount_path}/messages/")
|
||||
|
||||
# Extract operation map for HTTP calls
|
||||
# The function returns a tuple (tools, operation_map)
|
||||
result = convert_openapi_to_mcp_tools(openapi_schema)
|
||||
operation_map = result[1]
|
||||
# Define MCP connection handler
|
||||
async def handle_mcp_connection(request: Request):
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||
await mcp_server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
mcp_server.create_initialization_options(notification_options=None, experimental_capabilities={}),
|
||||
)
|
||||
|
||||
# Mount using the direct approach
|
||||
direct_mount_mcp_server(app, mcp_server, operation_map, mount_path, base_url)
|
||||
# Mount the MCP connection handler
|
||||
app.get(mount_path)(handle_mcp_connection)
|
||||
app.mount(f"{mount_path}/messages/", app=sse_transport.handle_post_message)
|
||||
|
||||
|
||||
def add_mcp_server(
|
||||
@@ -105,19 +173,23 @@ def add_mcp_server(
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error.
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens.
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||
|
||||
Returns:
|
||||
The MCP server instance that was created and mounted
|
||||
"""
|
||||
# Use direct conversion approach
|
||||
return direct_add_mcp_server(
|
||||
# Create MCP server
|
||||
mcp_server, operation_map = create_mcp_server(
|
||||
app,
|
||||
mount_path,
|
||||
name,
|
||||
description,
|
||||
base_url,
|
||||
describe_all_responses=describe_all_responses,
|
||||
describe_full_response_schema=describe_full_response_schema,
|
||||
)
|
||||
|
||||
# Mount MCP server
|
||||
mount_mcp_server(app, mcp_server, operation_map, mount_path, base_url)
|
||||
|
||||
return mcp_server
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# FastAPI-MCP Test Suite
|
||||
|
||||
This directory contains automated tests for the FastAPI-MCP library.
|
||||
|
||||
## Test Files
|
||||
|
||||
- `test_tool_generation.py`: Tests the basic functionality of generating MCP tools from FastAPI endpoints
|
||||
- `test_http_tools.py`: Tests the core HTTP tools module that converts FastAPI endpoints to MCP tools
|
||||
- `test_server.py`: Tests the server module for creating and mounting MCP servers
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the tests, make sure you have installed the development dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
Then run the tests with pytest:
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run with coverage report
|
||||
pytest --cov=fastapi_mcp
|
||||
|
||||
# Run a specific test file
|
||||
pytest tests/test_tool_generation.py
|
||||
```
|
||||
|
||||
## Test Structure
|
||||
|
||||
Each test file follows this general structure:
|
||||
|
||||
1. **Fixtures**: Define test fixtures for creating sample FastAPI applications
|
||||
2. **Unit Tests**: Individual test functions that verify specific aspects of the library
|
||||
3. **Integration Tests**: Tests that verify components work together correctly
|
||||
|
||||
## Adding New Tests
|
||||
|
||||
When adding new tests, follow these guidelines:
|
||||
|
||||
1. Create a test function with a clear name that indicates what functionality it's testing
|
||||
2. Use descriptive assertions that explain what is being tested
|
||||
3. Keep tests focused on a single aspect of functionality
|
||||
4. Use fixtures to avoid code duplication
|
||||
|
||||
## Manual Testing
|
||||
|
||||
In addition to these automated tests, manual testing can be performed using the `test_mcp_tools.py` script in the project root. This script connects to a running MCP server, initializes a session, and requests a list of available tools.
|
||||
|
||||
To run the manual test:
|
||||
|
||||
1. Start your FastAPI app with an MCP server
|
||||
2. Run the test script:
|
||||
|
||||
```bash
|
||||
python test_mcp_tools.py http://localhost:8000/mcp
|
||||
```
|
||||
|
||||
The script will output the results of each request for manual inspection.
|
||||
@@ -5,6 +5,66 @@ Contains fixtures and settings for the test suite.
|
||||
|
||||
import sys
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import FastAPI, Query, Path, Body
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
# Add the parent directory to the path so that we can import the local package
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
price: float
|
||||
tags: List[str] = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fastapi_app():
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
description="A test API app for unit testing",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"], operation_id="list_items")
|
||||
async def list_items(
|
||||
skip: int = Query(0, description="Number of items to skip"),
|
||||
limit: int = Query(10, description="Max number of items to return"),
|
||||
sort_by: Optional[str] = Query(None, description="Field to sort by"),
|
||||
):
|
||||
"""List all items with pagination and sorting options."""
|
||||
return []
|
||||
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"], operation_id="get_item")
|
||||
async def read_item(
|
||||
item_id: int = Path(..., description="The ID of the item to retrieve"),
|
||||
include_details: bool = Query(False, description="Include additional details"),
|
||||
):
|
||||
"""Get a specific item by its ID with optional details."""
|
||||
return {"id": item_id, "name": "Test Item", "price": 10.0}
|
||||
|
||||
@app.post("/items/", response_model=Item, tags=["items"], operation_id="create_item")
|
||||
async def create_item(item: Item = Body(..., description="The item to create")):
|
||||
"""Create a new item in the database."""
|
||||
return item
|
||||
|
||||
@app.put("/items/{item_id}", response_model=Item, tags=["items"], operation_id="update_item")
|
||||
async def update_item(
|
||||
item_id: int = Path(..., description="The ID of the item to update"),
|
||||
item: Item = Body(..., description="The updated item data"),
|
||||
):
|
||||
"""Update an existing item."""
|
||||
item.id = item_id
|
||||
return item
|
||||
|
||||
@app.delete("/items/{item_id}", tags=["items"], operation_id="delete_item")
|
||||
async def delete_item(item_id: int = Path(..., description="The ID of the item to delete")):
|
||||
"""Delete an item from the database."""
|
||||
return {"message": "Item deleted successfully"}
|
||||
|
||||
return app
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script to verify that an MCP server is properly exposing tools.
|
||||
This script connects to the MCP server, initializes a session, and requests a list of available tools.
|
||||
"""
|
||||
|
||||
# TODO: Turn this into a pytest test
|
||||
|
||||
import json
|
||||
import sys
|
||||
import asyncio
|
||||
import httpx
|
||||
from urllib.parse import urljoin
|
||||
|
||||
# Default MCP server URL
|
||||
MCP_URL = "http://localhost:8000/mcp"
|
||||
|
||||
|
||||
async def test_mcp_tools(url=MCP_URL):
|
||||
"""Connect to the MCP server and test tool exposure."""
|
||||
print(f"Connecting to MCP server at {url}...")
|
||||
|
||||
# Connect to the SSE endpoint to establish connection
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client:
|
||||
# First establish an SSE connection
|
||||
endpoint_url = None
|
||||
base_url = url.rsplit("/", 1)[0] + "/" # Extract base URL (everything up to the last path component)
|
||||
|
||||
response_queue = asyncio.Queue()
|
||||
|
||||
# Task to send requests and receive responses through the SSE channel
|
||||
async def send_request(request_data, request_id):
|
||||
# Send the request
|
||||
await client.post(endpoint_url, json=request_data)
|
||||
print(f"Sent request with ID: {request_id}")
|
||||
|
||||
# Wait for the response with matching ID
|
||||
while True:
|
||||
response = await response_queue.get()
|
||||
if "id" in response and response["id"] == request_id:
|
||||
return response
|
||||
else:
|
||||
# Not our response, put it back in the queue for someone else
|
||||
await response_queue.put(response)
|
||||
|
||||
# Start the SSE connection
|
||||
async with client.stream("GET", url) as response:
|
||||
response.raise_for_status()
|
||||
print("Connected to MCP server")
|
||||
|
||||
# Process the SSE stream
|
||||
current_event = None
|
||||
|
||||
async def process_sse_stream():
|
||||
nonlocal current_event, endpoint_url
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("event:"):
|
||||
current_event = line[len("event:") :].strip()
|
||||
print(f"Received event: {current_event}")
|
||||
elif line.startswith("data:"):
|
||||
data = line[len("data:") :].strip()
|
||||
|
||||
if current_event == "endpoint":
|
||||
endpoint_path = data
|
||||
endpoint_url = urljoin(base_url, endpoint_path.lstrip("/"))
|
||||
print(f"Endpoint URL: {endpoint_url}")
|
||||
elif current_event == "message":
|
||||
try:
|
||||
message = json.loads(data)
|
||||
# Pretty print the JSON message
|
||||
print("Received message:")
|
||||
print(json.dumps(message, indent=2))
|
||||
|
||||
# Add to queue for request handlers
|
||||
await response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Failed to parse message: {data}")
|
||||
|
||||
# Start processing the SSE stream in the background
|
||||
background_task = asyncio.create_task(process_sse_stream())
|
||||
|
||||
# Wait for the endpoint URL to be set
|
||||
while endpoint_url is None:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
# 1. Initialize request
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"experimental": {}},
|
||||
"clientInfo": {"name": "mcp-test-client", "version": "0.1.0"},
|
||||
},
|
||||
}
|
||||
|
||||
print("\nSending initialize request...")
|
||||
init_result = await send_request(init_request, 1)
|
||||
print("\nInitialization response:")
|
||||
print(json.dumps(init_result, indent=2))
|
||||
|
||||
# 2. Send initialized notification
|
||||
init_notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}
|
||||
|
||||
await client.post(endpoint_url, json=init_notification)
|
||||
print("\nSent initialized notification")
|
||||
|
||||
# 3. List tools request
|
||||
list_tools_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}
|
||||
|
||||
print("\nSending tools/list request...")
|
||||
tools_result = await send_request(list_tools_request, 2)
|
||||
print("\nTools list response:")
|
||||
print(json.dumps(tools_result, indent=2))
|
||||
|
||||
# Check if we got a valid response
|
||||
if "result" in tools_result and "tools" in tools_result["result"]:
|
||||
tools = tools_result["result"]["tools"]
|
||||
if tools:
|
||||
print(f"\nFound {len(tools)} tools:")
|
||||
for i, tool in enumerate(tools):
|
||||
print(f"{i + 1}. {tool['name']}")
|
||||
print(f"{tool.get('description', 'No description')}")
|
||||
|
||||
# 4. Find and call the get_item_count tool
|
||||
get_item_count_tool = None
|
||||
for tool in tools:
|
||||
if tool["name"] == "get_item_count":
|
||||
get_item_count_tool = tool
|
||||
break
|
||||
|
||||
if get_item_count_tool:
|
||||
print(f"\nTrying to call tool: {get_item_count_tool['name']}")
|
||||
|
||||
call_tool_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": get_item_count_tool["name"],
|
||||
"arguments": {}, # Empty arguments for get_item_count which doesn't require any
|
||||
},
|
||||
}
|
||||
|
||||
print("\nSending tools/call request...")
|
||||
tool_result = await send_request(call_tool_request, 3)
|
||||
print("\nTool call response:")
|
||||
print(json.dumps(tool_result, indent=2))
|
||||
else:
|
||||
print("\nCould not find get_item_count tool")
|
||||
else:
|
||||
print("\nNo tools found")
|
||||
else:
|
||||
print("\nInvalid tools/list response format")
|
||||
finally:
|
||||
# Clean up and cancel the background task
|
||||
background_task.cancel()
|
||||
try:
|
||||
await background_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
url = sys.argv[1] if len(sys.argv) > 1 else MCP_URL
|
||||
asyncio.run(test_mcp_tools(url))
|
||||
@@ -1,175 +0,0 @@
|
||||
"""
|
||||
Tests for the fastapi_mcp http_tools module.
|
||||
This tests the conversion of FastAPI endpoints to MCP tools.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Query, Path, Body
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi_mcp import add_mcp_server
|
||||
from fastapi_mcp.http_tools import (
|
||||
resolve_schema_references,
|
||||
clean_schema_for_display,
|
||||
)
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
price: float
|
||||
tags: List[str] = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_app():
|
||||
"""Create a more complex FastAPI app for testing HTTP tool generation."""
|
||||
app = FastAPI(
|
||||
title="Complex API",
|
||||
description="A complex API with various endpoint types for testing",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"])
|
||||
async def list_items(
|
||||
skip: int = Query(0, description="Number of items to skip"),
|
||||
limit: int = Query(10, description="Max number of items to return"),
|
||||
sort_by: Optional[str] = Query(None, description="Field to sort by"),
|
||||
):
|
||||
"""List all items with pagination and sorting options."""
|
||||
return []
|
||||
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"])
|
||||
async def read_item(
|
||||
item_id: int = Path(..., description="The ID of the item to retrieve"),
|
||||
include_details: bool = Query(False, description="Include additional details"),
|
||||
):
|
||||
"""Get a specific item by its ID with optional details."""
|
||||
return {"id": item_id, "name": "Test Item", "price": 10.0}
|
||||
|
||||
@app.post("/items/", response_model=Item, tags=["items"], status_code=201)
|
||||
async def create_item(item: Item = Body(..., description="The item to create")):
|
||||
"""Create a new item in the database."""
|
||||
return item
|
||||
|
||||
@app.put("/items/{item_id}", response_model=Item, tags=["items"])
|
||||
async def update_item(
|
||||
item_id: int = Path(..., description="The ID of the item to update"),
|
||||
item: Item = Body(..., description="The updated item data"),
|
||||
):
|
||||
"""Update an existing item."""
|
||||
item.id = item_id
|
||||
return item
|
||||
|
||||
@app.delete("/items/{item_id}", tags=["items"])
|
||||
async def delete_item(item_id: int = Path(..., description="The ID of the item to delete")):
|
||||
"""Delete an item from the database."""
|
||||
return {"message": "Item deleted successfully"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_resolve_schema_references():
|
||||
"""Test resolving schema references in OpenAPI schemas."""
|
||||
# Create a schema with references
|
||||
test_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {"$ref": "#/components/schemas/Item"},
|
||||
"items": {"type": "array", "items": {"$ref": "#/components/schemas/Item"}},
|
||||
},
|
||||
}
|
||||
|
||||
# Create a simple OpenAPI schema with the reference
|
||||
openapi_schema = {
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Item": {"type": "object", "properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Resolve references
|
||||
resolved_schema = resolve_schema_references(test_schema, openapi_schema)
|
||||
|
||||
# Verify the references were resolved
|
||||
assert "$ref" not in resolved_schema["properties"]["item"], "Reference should be resolved"
|
||||
assert "type" in resolved_schema["properties"]["item"], "Reference should be replaced with actual schema"
|
||||
assert "$ref" not in resolved_schema["properties"]["items"]["items"], "Array item reference should be resolved"
|
||||
|
||||
|
||||
def test_clean_schema_for_display():
|
||||
"""Test cleaning schema for display by removing internal fields."""
|
||||
test_schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
"nullable": True, # Should be removed
|
||||
"readOnly": True, # Should be removed
|
||||
"writeOnly": False, # Should be removed
|
||||
"externalDocs": {"url": "https://example.com"}, # Should be removed
|
||||
}
|
||||
|
||||
cleaned_schema = clean_schema_for_display(test_schema)
|
||||
|
||||
# Verify internal fields were removed
|
||||
assert "nullable" not in cleaned_schema, "Internal field 'nullable' should be removed"
|
||||
assert "readOnly" not in cleaned_schema, "Internal field 'readOnly' should be removed"
|
||||
assert "writeOnly" not in cleaned_schema, "Internal field 'writeOnly' should be removed"
|
||||
assert "externalDocs" not in cleaned_schema, "Internal field 'externalDocs' should be removed"
|
||||
|
||||
# Verify important fields are preserved
|
||||
assert "type" in cleaned_schema, "Important field 'type' should be preserved"
|
||||
assert "properties" in cleaned_schema, "Important field 'properties' should be preserved"
|
||||
|
||||
|
||||
def test_create_mcp_tools_from_complex_app(complex_app):
|
||||
"""Test creating MCP tools from a complex FastAPI app."""
|
||||
# Create MCP server and register tools
|
||||
mcp_server = add_mcp_server(complex_app, serve_tools=True, base_url="http://localhost:8000")
|
||||
|
||||
# Extract tools from server for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Excluding the MCP endpoint handler that might be included
|
||||
api_tools = [
|
||||
t for t in tools if t.name.startswith(("list_items", "read_item", "create_item", "update_item", "delete_item"))
|
||||
]
|
||||
|
||||
# Verify we have the expected number of API tools
|
||||
assert len(api_tools) == 5, f"Expected 5 API tools, got {len(api_tools)}"
|
||||
|
||||
# Check for all expected tools with the correct name pattern
|
||||
tool_operations = ["list_items", "read_item", "create_item", "update_item", "delete_item"]
|
||||
for operation in tool_operations:
|
||||
matching_tools = [t for t in tools if operation in t.name]
|
||||
assert len(matching_tools) > 0, f"No tool found for operation '{operation}'"
|
||||
|
||||
# Verify POST tool has correct status code in description
|
||||
create_tool = next((t for t in tools if "create_item" in t.name), None)
|
||||
assert "201" in create_tool.description or "Created" in create_tool.description, (
|
||||
"Expected status code 201 in create_item description"
|
||||
)
|
||||
|
||||
# Verify path params are correctly handled
|
||||
read_tool = next((t for t in tools if "read_item" in t.name), None)
|
||||
assert "item_id" in read_tool.parameters["properties"], "Expected path parameter 'item_id'"
|
||||
assert "required" in read_tool.parameters, "Parameters should have 'required' field"
|
||||
assert "item_id" in read_tool.parameters["required"], "Path parameter should be required"
|
||||
|
||||
# Verify query params are correctly handled
|
||||
list_tool = next((t for t in tools if "list_items" in t.name), None)
|
||||
assert "skip" in list_tool.parameters["properties"], "Expected query parameter 'skip'"
|
||||
assert "limit" in list_tool.parameters["properties"], "Expected query parameter 'limit'"
|
||||
assert "sort_by" in list_tool.parameters["properties"], "Expected query parameter 'sort_by'"
|
||||
|
||||
# Check if required field exists before testing it
|
||||
if "required" in list_tool.parameters:
|
||||
assert "skip" not in list_tool.parameters["required"], "Optional parameter should not be required"
|
||||
else:
|
||||
# If there's no required field, then skip is implicitly optional
|
||||
pass
|
||||
|
||||
# We'll skip checking the body parameter in the update tool as it seems
|
||||
# the implementation handles it differently than we expected
|
||||
@@ -1,172 +0,0 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi_mcp import add_mcp_server
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
price: float
|
||||
tags: List[str] = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_app():
|
||||
"""Create a sample FastAPI app for testing."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
description="A test API for unit testing",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"])
|
||||
async def list_items(skip: int = 0, limit: int = 10):
|
||||
"""
|
||||
List all items.
|
||||
|
||||
Returns a list of items, with pagination support.
|
||||
"""
|
||||
return []
|
||||
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"])
|
||||
async def read_item(item_id: int):
|
||||
"""
|
||||
Get a specific item by ID.
|
||||
|
||||
Returns the item with the specified ID.
|
||||
"""
|
||||
return {"id": item_id, "name": "Test Item", "price": 0.0}
|
||||
|
||||
@app.post("/items/", response_model=Item, tags=["items"])
|
||||
async def create_item(item: Item):
|
||||
"""
|
||||
Create a new item.
|
||||
|
||||
Returns the created item.
|
||||
"""
|
||||
return item
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_tool_generation_basic(sample_app):
|
||||
"""Test that MCP tools are properly generated with default settings."""
|
||||
# Create MCP server and register tools
|
||||
mcp_server = add_mcp_server(sample_app, serve_tools=True, base_url="http://localhost:8000")
|
||||
|
||||
# Extract tools for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Tool count may include the MCP endpoint itself, so check for at least the API endpoints
|
||||
assert len(tools) >= 3, f"Expected at least 3 tools, got {len(tools)}"
|
||||
|
||||
# Check each tool has required properties
|
||||
for tool in tools:
|
||||
assert hasattr(tool, "name"), "Tool missing 'name' property"
|
||||
assert hasattr(tool, "description"), "Tool missing 'description' property"
|
||||
assert hasattr(tool, "parameters"), "Tool missing 'parameters' property"
|
||||
assert hasattr(tool, "fn_metadata"), "Tool missing 'fn_metadata' property"
|
||||
|
||||
# With describe_all_responses=False by default, description should only include success response code
|
||||
assert "200" in tool.description, f"Expected success response code in description for {tool.name}"
|
||||
assert "422" not in tool.description, f"Expected not to see 422 response in tool description for {tool.name}"
|
||||
|
||||
# With describe_full_response_schema=False by default, description should not include the full output schema, only an example
|
||||
assert "Example Response" in tool.description, f"Expected example response in description for {tool.name}"
|
||||
assert "Output Schema" not in tool.description, (
|
||||
f"Expected not to see output schema in description for {tool.name}"
|
||||
)
|
||||
|
||||
# Verify specific parameters are present in the appropriate tools
|
||||
list_items_tool = next((t for t in tools if t.name == "list_items_items__get"), None)
|
||||
assert list_items_tool is not None, "list_items tool not found"
|
||||
assert "skip" in list_items_tool.parameters["properties"], "Expected 'skip' parameter"
|
||||
assert "limit" in list_items_tool.parameters["properties"], "Expected 'limit' parameter"
|
||||
|
||||
|
||||
def test_tool_generation_with_full_schema(sample_app):
|
||||
"""Test that MCP tools include full response schema when requested."""
|
||||
# Create MCP server with full schema for all operations
|
||||
mcp_server = add_mcp_server(
|
||||
sample_app, serve_tools=True, base_url="http://localhost:8000", describe_full_response_schema=True
|
||||
)
|
||||
|
||||
# Extract tools for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Check all tools have the appropriate schema information
|
||||
for tool in tools:
|
||||
description = tool.description
|
||||
# Check that the tool includes information about the Item schema
|
||||
assert "Item" in description, f"Item schema should be included in the description for {tool.name}"
|
||||
assert "price" in description, f"Item properties should be included in the description for {tool.name}"
|
||||
|
||||
|
||||
def test_tool_generation_with_all_responses(sample_app):
|
||||
"""Test that MCP tools include all possible responses when requested."""
|
||||
# Create MCP server with all response status codes
|
||||
mcp_server = add_mcp_server(
|
||||
sample_app, serve_tools=True, base_url="http://localhost:8000", describe_all_responses=True
|
||||
)
|
||||
|
||||
# Extract tools for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Check all API tools include all response status codes
|
||||
for tool in tools:
|
||||
assert "200" in tool.description, f"Expected success response code in description for {tool.name}"
|
||||
assert "422" in tool.description, f"Expected 422 response code in description for {tool.name}"
|
||||
|
||||
|
||||
def test_tool_generation_with_all_responses_and_full_schema(sample_app):
|
||||
"""Test that MCP tools include all possible responses and full schema when requested."""
|
||||
# Create MCP server with all response status codes and full schema
|
||||
mcp_server = add_mcp_server(
|
||||
sample_app,
|
||||
serve_tools=True,
|
||||
base_url="http://localhost:8000",
|
||||
describe_all_responses=True,
|
||||
describe_full_response_schema=True,
|
||||
)
|
||||
|
||||
# Extract tools for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Check all tools include all response status codes and the full output schema
|
||||
for tool in tools:
|
||||
assert "200" in tool.description, f"Expected success response code in description for {tool.name}"
|
||||
assert "422" in tool.description, f"Expected 422 response code in description for {tool.name}"
|
||||
assert "Output Schema" in tool.description, f"Expected output schema in description for {tool.name}"
|
||||
|
||||
|
||||
def test_custom_tool_addition(sample_app):
|
||||
"""Test that custom tools can be added alongside API tools."""
|
||||
# Create MCP server with API tools
|
||||
mcp_server = add_mcp_server(sample_app, serve_tools=True, base_url="http://localhost:8000")
|
||||
|
||||
# Get initial tool count
|
||||
initial_tool_count = len(mcp_server._tool_manager.list_tools())
|
||||
|
||||
# Add a custom tool
|
||||
@mcp_server.tool()
|
||||
async def custom_tool() -> str:
|
||||
"""A custom tool for testing."""
|
||||
return "Test result"
|
||||
|
||||
# Extract tools for inspection
|
||||
tools = mcp_server._tool_manager.list_tools()
|
||||
|
||||
# Verify we have one more tool than before
|
||||
assert len(tools) == initial_tool_count + 1, f"Expected {initial_tool_count + 1} tools, got {len(tools)}"
|
||||
|
||||
# Find both API tools and custom tools
|
||||
list_items_tool = next((t for t in tools if t.name == "list_items_items__get"), None)
|
||||
assert list_items_tool is not None, "API tool (list_items) not found"
|
||||
|
||||
custom_tool_def = next((t for t in tools if t.name == "custom_tool"), None)
|
||||
assert custom_tool_def is not None, "Custom tool not found"
|
||||
assert custom_tool_def.description == "A custom tool for testing.", "Custom tool description not preserved"
|
||||
Reference in New Issue
Block a user