mirror of
https://github.com/tadata-org/fastapi_mcp.git
synced 2025-04-13 23:32:11 +03:00
enhance test suite and fix errors
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import httpx
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Optional, Any, List, Union, AsyncIterator
|
||||
from typing import Dict, Optional, Any, List, Union
|
||||
|
||||
from fastapi import FastAPI, Request, APIRouter
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
@@ -10,6 +9,7 @@ import mcp.types as types
|
||||
|
||||
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
|
||||
from fastapi_mcp.transport.sse import FastApiSseTransport
|
||||
from fastapi_mcp.types import AsyncClientProtocol
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
@@ -26,7 +26,24 @@ class FastApiMCP:
|
||||
base_url: Optional[str] = None,
|
||||
describe_all_responses: bool = False,
|
||||
describe_full_response_schema: bool = False,
|
||||
http_client: Optional[AsyncClientProtocol] = None,
|
||||
):
|
||||
"""
|
||||
Create an MCP server from a FastAPI app.
|
||||
|
||||
Args:
|
||||
fastapi: The FastAPI application
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
|
||||
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
|
||||
as the root path would be different when the app is deployed.
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||
http_client: Optional HTTP client to use for API calls. If not provided, a new httpx.AsyncClient will be created.
|
||||
This is primarily for testing purposes.
|
||||
"""
|
||||
|
||||
self.operation_map: Dict[str, Dict[str, Any]]
|
||||
self.tools: List[types.Tool]
|
||||
|
||||
@@ -38,27 +55,11 @@ class FastApiMCP:
|
||||
self._describe_all_responses = describe_all_responses
|
||||
self._describe_full_response_schema = describe_full_response_schema
|
||||
|
||||
self._http_client = http_client or httpx.AsyncClient()
|
||||
|
||||
self.server = self.create_server()
|
||||
|
||||
def create_server(self) -> Server:
|
||||
"""
|
||||
Create an MCP server from the FastAPI app.
|
||||
|
||||
Args:
|
||||
fastapi: The FastAPI application
|
||||
name: Name for the MCP server (defaults to app.title)
|
||||
description: Description for the MCP server (defaults to app.description)
|
||||
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
|
||||
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
|
||||
as the root path would be different when the app is deployed.
|
||||
describe_all_responses: Whether to include all possible response schemas in tool descriptions
|
||||
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The created MCP Server instance (NOT mounted to the app)
|
||||
- A mapping of operation IDs to operation details for HTTP execution
|
||||
"""
|
||||
# Get OpenAPI schema from FastAPI app
|
||||
openapi_schema = get_openapi(
|
||||
title=self.fastapi.title,
|
||||
@@ -93,23 +94,12 @@ class FastApiMCP:
|
||||
if self._base_url.endswith("/"):
|
||||
self._base_url = self._base_url[:-1]
|
||||
|
||||
# Create the MCP server
|
||||
# Create the MCP lowlevel server
|
||||
mcp_server: Server = Server(self.name, self.description)
|
||||
|
||||
# Create a lifespan context manager to store the base_url and operation_map
|
||||
@asynccontextmanager
|
||||
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
|
||||
# Store context data that will be available to all server handlers
|
||||
context = {"base_url": self._base_url, "operation_map": self.operation_map}
|
||||
yield context
|
||||
|
||||
# Use our custom lifespan
|
||||
mcp_server.lifespan = server_lifespan
|
||||
|
||||
# Register handlers for tools
|
||||
@mcp_server.list_tools()
|
||||
async def handle_list_tools() -> List[types.Tool]:
|
||||
"""Handler for the tools/list request"""
|
||||
return self.tools
|
||||
|
||||
# Register the tool call handler
|
||||
@@ -117,14 +107,13 @@ class FastApiMCP:
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: Dict[str, Any]
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""Handler for the tools/call request"""
|
||||
# Get context from server lifespan
|
||||
ctx = mcp_server.request_context
|
||||
base_url = ctx.lifespan_context["base_url"]
|
||||
operation_map = ctx.lifespan_context["operation_map"]
|
||||
|
||||
# Execute the tool
|
||||
return await self.execute_api_tool(base_url, name, arguments, operation_map)
|
||||
return await self._execute_api_tool(
|
||||
client=self._http_client,
|
||||
base_url=self._base_url or "",
|
||||
tool_name=name,
|
||||
arguments=arguments,
|
||||
operation_map=self.operation_map,
|
||||
)
|
||||
|
||||
return mcp_server
|
||||
|
||||
@@ -168,8 +157,13 @@ class FastApiMCP:
|
||||
|
||||
logger.info(f"MCP server listening at {mount_path}")
|
||||
|
||||
async def execute_api_tool(
|
||||
self, base_url: str, tool_name: str, arguments: Dict[str, Any], operation_map: Dict[str, Dict[str, Any]]
|
||||
async def _execute_api_tool(
|
||||
self,
|
||||
client: AsyncClientProtocol,
|
||||
base_url: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
operation_map: Dict[str, Dict[str, Any]],
|
||||
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
|
||||
"""
|
||||
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
|
||||
@@ -179,12 +173,13 @@ class FastApiMCP:
|
||||
tool_name: The name of the tool to execute
|
||||
arguments: The arguments for the tool
|
||||
operation_map: A mapping from tool names to operation details
|
||||
client: Optional HTTP client to use (primarily for testing)
|
||||
|
||||
Returns:
|
||||
The result as MCP content types
|
||||
"""
|
||||
if tool_name not in operation_map:
|
||||
return [types.TextContent(type="text", text=f"Unknown tool: {tool_name}")]
|
||||
raise Exception(f"Unknown tool: {tool_name}")
|
||||
|
||||
operation = operation_map[tool_name]
|
||||
path: str = operation["path"]
|
||||
@@ -192,7 +187,6 @@ class FastApiMCP:
|
||||
parameters: List[Dict[str, Any]] = operation.get("parameters", [])
|
||||
arguments = arguments.copy() if arguments else {} # Deep copy arguments to avoid mutating the original
|
||||
|
||||
# Prepare URL with path parameters
|
||||
url = f"{base_url}{path}"
|
||||
for param in parameters:
|
||||
if param.get("in") == "path" and param.get("name") in arguments:
|
||||
@@ -201,7 +195,6 @@ class FastApiMCP:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
url = url.replace(f"{{{param_name}}}", str(arguments.pop(param_name)))
|
||||
|
||||
# Prepare query parameters
|
||||
query = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "query" and param.get("name") in arguments:
|
||||
@@ -210,7 +203,6 @@ class FastApiMCP:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
query[param_name] = arguments.pop(param_name)
|
||||
|
||||
# Prepare headers
|
||||
headers = {}
|
||||
for param in parameters:
|
||||
if param.get("in") == "header" and param.get("name") in arguments:
|
||||
@@ -219,32 +211,57 @@ class FastApiMCP:
|
||||
raise ValueError(f"Parameter name is None for parameter: {param}")
|
||||
headers[param_name] = arguments.pop(param_name)
|
||||
|
||||
# Prepare request body (remaining kwargs)
|
||||
body = arguments if arguments else None
|
||||
|
||||
try:
|
||||
# Make request
|
||||
logger.debug(f"Making {method.upper()} request to {url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.lower() == "get":
|
||||
response = await client.get(url, params=query, headers=headers)
|
||||
elif method.lower() == "post":
|
||||
response = await client.post(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "put":
|
||||
response = await client.put(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "delete":
|
||||
response = await client.delete(url, params=query, headers=headers)
|
||||
elif method.lower() == "patch":
|
||||
response = await client.patch(url, params=query, headers=headers, json=body)
|
||||
else:
|
||||
return [types.TextContent(type="text", text=f"Unsupported HTTP method: {method}")]
|
||||
response = await self._request(client, method, url, query, headers, body)
|
||||
|
||||
# Process response
|
||||
# TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc.
|
||||
try:
|
||||
result = response.json()
|
||||
return [types.TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||
result_text = json.dumps(result, indent=2)
|
||||
except json.JSONDecodeError:
|
||||
if hasattr(response, "text"):
|
||||
result_text = response.text
|
||||
else:
|
||||
result_text = response.content
|
||||
|
||||
# If not raising an exception, the MCP server will return the result as a regular text response, without marking it as an error.
|
||||
# TODO: Use a raise_for_status() method on the response (it needs to also be implemented in the AsyncClientProtocol)
|
||||
if 400 <= response.status_code < 600:
|
||||
raise Exception(
|
||||
f"Error calling {tool_name}. Status code: {response.status_code}. Response: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
return [types.TextContent(type="text", text=result_text)]
|
||||
except ValueError:
|
||||
return [types.TextContent(type="text", text=response.text)]
|
||||
return [types.TextContent(type="text", text=result_text)]
|
||||
|
||||
except Exception as e:
|
||||
return [types.TextContent(type="text", text=f"Error calling {tool_name}: {str(e)}")]
|
||||
logger.exception(f"Error calling {tool_name}")
|
||||
raise e
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
client: AsyncClientProtocol,
|
||||
method: str,
|
||||
url: str,
|
||||
query: Dict[str, Any],
|
||||
headers: Dict[str, str],
|
||||
body: Optional[Any],
|
||||
) -> Any:
|
||||
"""Helper method to make the actual HTTP request"""
|
||||
if method.lower() == "get":
|
||||
return await client.get(url, params=query, headers=headers)
|
||||
elif method.lower() == "post":
|
||||
return await client.post(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "put":
|
||||
return await client.put(url, params=query, headers=headers, json=body)
|
||||
elif method.lower() == "delete":
|
||||
return await client.delete(url, params=query, headers=headers)
|
||||
elif method.lower() == "patch":
|
||||
return await client.patch(url, params=query, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
46
fastapi_mcp/types.py
Normal file
46
fastapi_mcp/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from typing import Any, Protocol, Optional, Dict
|
||||
|
||||
|
||||
class BaseType(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AsyncClientProtocol(Protocol):
|
||||
"""Protocol defining the interface for async HTTP clients."""
|
||||
|
||||
async def get(
|
||||
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
|
||||
) -> Any: ...
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any: ...
|
||||
|
||||
async def put(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any: ...
|
||||
|
||||
async def delete(
|
||||
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
|
||||
) -> Any: ...
|
||||
|
||||
async def patch(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any: ...
|
||||
0
fastapi_mcp/utils/__init__.py
Normal file
0
fastapi_mcp/utils/__init__.py
Normal file
63
fastapi_mcp/utils/testing.py
Normal file
63
fastapi_mcp/utils/testing.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from fastapi_mcp.server import AsyncClientProtocol
|
||||
|
||||
|
||||
class FastAPITestClient(AsyncClientProtocol):
|
||||
def __init__(self, app: FastAPI):
|
||||
self.client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
async def get(
|
||||
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
|
||||
) -> Any:
|
||||
response = self.client.get(url, params=params, headers=headers)
|
||||
return self._wrap_response(response)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any:
|
||||
response = self.client.post(url, params=params, headers=headers, json=json)
|
||||
return self._wrap_response(response)
|
||||
|
||||
async def put(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any:
|
||||
response = self.client.put(url, params=params, headers=headers, json=json)
|
||||
return self._wrap_response(response)
|
||||
|
||||
async def delete(
|
||||
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
|
||||
) -> Any:
|
||||
response = self.client.delete(url, params=params, headers=headers)
|
||||
return self._wrap_response(response)
|
||||
|
||||
async def patch(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json: Optional[Any] = None,
|
||||
) -> Any:
|
||||
response = self.client.patch(url, params=params, headers=headers, json=json)
|
||||
return self._wrap_response(response)
|
||||
|
||||
def _wrap_response(self, response: Any) -> Any:
|
||||
response.json = (
|
||||
lambda: json.loads(response.content) if hasattr(response, "content") and response.content else None
|
||||
)
|
||||
return response
|
||||
@@ -8,55 +8,3 @@ from .fixtures.types import * # noqa: F403
|
||||
from .fixtures.example_data import * # noqa: F403
|
||||
from .fixtures.simple_app import * # noqa: F403
|
||||
from .fixtures.complex_app import * # noqa: F403
|
||||
|
||||
# Add specific fixtures for MCP testing
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from fastapi_mcp import FastApiMCP
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server(simple_fastapi_app):
|
||||
"""
|
||||
Create a basic MCP server instance for the simple_fastapi_app.
|
||||
This is a utility fixture to be used by multiple tests.
|
||||
"""
|
||||
return FastApiMCP(
|
||||
simple_fastapi_app,
|
||||
name="Test MCP Server",
|
||||
description="Test MCP server for unit testing",
|
||||
base_url="http://testserver",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_mcp_server(complex_fastapi_app):
|
||||
"""
|
||||
Create a MCP server instance for the complex_fastapi_app.
|
||||
This is a utility fixture to be used by multiple tests.
|
||||
"""
|
||||
return FastApiMCP(
|
||||
complex_fastapi_app,
|
||||
name="Complex Test MCP Server",
|
||||
description="Complex test MCP server for unit testing",
|
||||
base_url="http://testserver",
|
||||
describe_all_responses=True,
|
||||
describe_full_response_schema=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(simple_fastapi_app):
|
||||
"""
|
||||
Create a test client for the simple_fastapi_app.
|
||||
"""
|
||||
return TestClient(simple_fastapi_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_client(complex_fastapi_app):
|
||||
"""
|
||||
Create a test client for the complex_fastapi_app.
|
||||
"""
|
||||
return TestClient(complex_fastapi_app)
|
||||
|
||||
25
tests/fixtures/simple_app.py
vendored
25
tests/fixtures/simple_app.py
vendored
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import FastAPI, Query, Path, Body
|
||||
from fastapi import FastAPI, Query, Path, Body, HTTPException
|
||||
import pytest
|
||||
|
||||
from .types import Item
|
||||
@@ -14,6 +14,12 @@ def simple_fastapi_app() -> FastAPI:
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
items = [
|
||||
Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"),
|
||||
Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]),
|
||||
Item(id=3, name="Item 3", price=30.0, tags=["tag3", "tag4"], description="Item 3 description"),
|
||||
]
|
||||
|
||||
@app.get("/items/", response_model=List[Item], tags=["items"], operation_id="list_items")
|
||||
async def list_items(
|
||||
skip: int = Query(0, description="Number of items to skip"),
|
||||
@@ -21,11 +27,7 @@ def simple_fastapi_app() -> FastAPI:
|
||||
sort_by: Optional[str] = Query(None, description="Field to sort by"),
|
||||
):
|
||||
"""List all items with pagination and sorting options."""
|
||||
return [
|
||||
Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"),
|
||||
Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]),
|
||||
Item(id=3, name="Item 3", price=30.0, tags=["tag3", "tag4"], description="Item 3 description"),
|
||||
]
|
||||
return items[skip : skip + limit]
|
||||
|
||||
@app.get("/items/{item_id}", response_model=Item, tags=["items"], operation_id="get_item")
|
||||
async def read_item(
|
||||
@@ -33,11 +35,15 @@ def simple_fastapi_app() -> FastAPI:
|
||||
include_details: bool = Query(False, description="Include additional details"),
|
||||
):
|
||||
"""Get a specific item by its ID with optional details."""
|
||||
return Item(id=item_id, name="Test Item", price=10.0, tags=["tag1", "tag2"])
|
||||
found_item = next((item for item in items if item.id == item_id), None)
|
||||
if found_item is None:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
return found_item
|
||||
|
||||
@app.post("/items/", response_model=Item, tags=["items"], operation_id="create_item")
|
||||
async def create_item(item: Item = Body(..., description="The item to create")):
|
||||
"""Create a new item in the database."""
|
||||
items.append(item)
|
||||
return item
|
||||
|
||||
@app.put("/items/{item_id}", response_model=Item, tags=["items"], operation_id="update_item")
|
||||
@@ -54,4 +60,9 @@ def simple_fastapi_app() -> FastAPI:
|
||||
"""Delete an item from the database."""
|
||||
return None
|
||||
|
||||
@app.get("/error", status_code=200, tags=["error"], operation_id="raise_error")
|
||||
async def raise_error():
|
||||
"""Fail on purpose and cause a 500 error."""
|
||||
raise Exception("This is a test error")
|
||||
|
||||
return app
|
||||
|
||||
218
tests/test_mcp_complex_app.py
Normal file
218
tests/test_mcp_complex_app.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.shared.memory import create_connected_server_and_client_session
|
||||
from fastapi import FastAPI
|
||||
|
||||
from fastapi_mcp import FastApiMCP
|
||||
from fastapi_mcp.utils.testing import FastAPITestClient
|
||||
|
||||
from .fixtures.types import Product, Customer, OrderResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lowlevel_server_complex_app(complex_fastapi_app: FastAPI) -> Server:
|
||||
mcp = FastApiMCP(
|
||||
complex_fastapi_app,
|
||||
name="Test MCP Server",
|
||||
description="Test description",
|
||||
base_url="",
|
||||
http_client=FastAPITestClient(complex_fastapi_app),
|
||||
)
|
||||
return mcp.server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(lowlevel_server_complex_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
tools_result = await client_session.list_tools()
|
||||
|
||||
assert len(tools_result.tools) > 0
|
||||
|
||||
tool_names = [tool.name for tool in tools_result.tools]
|
||||
expected_operations = ["list_products", "get_product", "create_order", "get_customer"]
|
||||
for op in expected_operations:
|
||||
assert op in tool_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_list_products_default(lowlevel_server_complex_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool("list_products", {})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert "items" in result
|
||||
assert result["total"] == 1
|
||||
assert result["page"] == 1
|
||||
assert len(result["items"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_list_products_with_filters(lowlevel_server_complex_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool(
|
||||
"list_products",
|
||||
{"category": "electronics", "min_price": 10.0, "page": 1, "size": 10, "in_stock_only": True},
|
||||
)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert "items" in result
|
||||
assert result["page"] == 1
|
||||
assert result["size"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_product(lowlevel_server_complex_app: Server, example_product: Product):
|
||||
product_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool("get_product", {"product_id": product_id})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == product_id
|
||||
assert "name" in result
|
||||
assert "price" in result
|
||||
assert "description" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_product_with_options(lowlevel_server_complex_app: Server):
|
||||
product_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool(
|
||||
"get_product", {"product_id": product_id, "include_unavailable": True}
|
||||
)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == product_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_create_order(lowlevel_server_complex_app: Server, example_order_response: OrderResponse):
|
||||
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
|
||||
product_id = "123e4567-e89b-12d3-a456-426614174001" # Valid UUID format
|
||||
shipping_address_id = "123e4567-e89b-12d3-a456-426614174002" # Valid UUID format
|
||||
|
||||
order_request = {
|
||||
"customer_id": customer_id,
|
||||
"items": [{"product_id": product_id, "quantity": 2, "unit_price": 29.99, "total": 59.98}],
|
||||
"shipping_address_id": shipping_address_id,
|
||||
"payment_method": "credit_card",
|
||||
}
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool("create_order", order_request)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["customer_id"] == customer_id
|
||||
assert "id" in result
|
||||
assert "status" in result
|
||||
assert "items" in result
|
||||
assert len(result["items"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_create_order_validation_error(lowlevel_server_complex_app: Server):
|
||||
# Missing required fields
|
||||
order_request = {
|
||||
# Missing customer_id
|
||||
"items": [],
|
||||
# Missing shipping_address_id
|
||||
"payment_method": "credit_card",
|
||||
}
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool("create_order", order_request)
|
||||
|
||||
assert response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert "422" in text_content.text or "validation" in text_content.text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_customer(lowlevel_server_complex_app: Server, example_customer: Customer):
|
||||
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool("get_customer", {"customer_id": customer_id})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == customer_id
|
||||
assert "full_name" in result
|
||||
assert "email" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_customer_with_options(lowlevel_server_complex_app: Server):
|
||||
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
response = await client_session.call_tool(
|
||||
"get_customer",
|
||||
{
|
||||
"customer_id": customer_id,
|
||||
"include_orders": True,
|
||||
"include_payment_methods": True,
|
||||
"fields": ["full_name", "email", "orders"],
|
||||
},
|
||||
)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == customer_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_missing_parameter(lowlevel_server_complex_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
|
||||
# Missing required product_id parameter
|
||||
response = await client_session.call_tool("get_product", {})
|
||||
|
||||
assert response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert (
|
||||
"422" in text_content.text
|
||||
or "parameter" in text_content.text.lower()
|
||||
or "field" in text_content.text.lower()
|
||||
)
|
||||
233
tests/test_mcp_simple_app.py
Normal file
233
tests/test_mcp_simple_app.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.shared.memory import create_connected_server_and_client_session
|
||||
from fastapi import FastAPI
|
||||
|
||||
from fastapi_mcp import FastApiMCP
|
||||
from fastapi_mcp.utils.testing import FastAPITestClient
|
||||
|
||||
from .fixtures.types import Item
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lowlevel_server_simple_app(simple_fastapi_app: FastAPI) -> Server:
|
||||
mcp = FastApiMCP(
|
||||
simple_fastapi_app,
|
||||
name="Test MCP Server",
|
||||
description="Test description",
|
||||
base_url="",
|
||||
http_client=FastAPITestClient(simple_fastapi_app),
|
||||
)
|
||||
return mcp.server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(lowlevel_server_simple_app: Server):
|
||||
"""Test listing tools via direct MCP connection."""
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
tools_result = await client_session.list_tools()
|
||||
|
||||
assert len(tools_result.tools) > 0
|
||||
|
||||
tool_names = [tool.name for tool in tools_result.tools]
|
||||
expected_operations = ["list_items", "get_item", "create_item", "update_item", "delete_item", "raise_error"]
|
||||
for op in expected_operations:
|
||||
assert op in tool_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_item_1(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("get_item", {"item_id": 1})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result: dict = json.loads(text_content.text)
|
||||
parsed_result = Item(**result)
|
||||
|
||||
assert parsed_result.id == 1
|
||||
assert parsed_result.name == "Item 1"
|
||||
assert parsed_result.price == 10.0
|
||||
assert parsed_result.tags == ["tag1", "tag2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_item_2(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("get_item", {"item_id": 2})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result: dict = json.loads(text_content.text)
|
||||
parsed_result = Item(**result)
|
||||
|
||||
assert parsed_result.id == 2
|
||||
assert parsed_result.name == "Item 2"
|
||||
assert parsed_result.price == 20.0
|
||||
assert parsed_result.tags == ["tag2", "tag3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_raise_error(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("raise_error", {})
|
||||
|
||||
assert response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert "500" in text_content.text
|
||||
assert "internal server error" in text_content.text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("get_item", {})
|
||||
|
||||
assert response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert "item_id" in text_content.text.lower() or "missing" in text_content.text.lower()
|
||||
assert "422" in text_content.text, "Expected a 422 status to appear in the response text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_tool_arguments(lowlevel_server_simple_app: Server):
|
||||
test_item = {
|
||||
"id": 42,
|
||||
"name": "Test Item",
|
||||
"description": "A test item for MCP",
|
||||
"price": 9.99,
|
||||
"tags": ["test", "mcp"],
|
||||
}
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("create_item", test_item)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == test_item["id"]
|
||||
assert result["name"] == test_item["name"]
|
||||
assert result["price"] == test_item["price"]
|
||||
assert result["tags"] == test_item["tags"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_list_items_default(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("list_items", {})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
results = json.loads(text_content.text)
|
||||
assert len(results) == 3 # Default should return all three items with default pagination
|
||||
|
||||
# Check first item matches expected data
|
||||
item = results[0]
|
||||
assert item["id"] == 1
|
||||
assert item["name"] == "Item 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_list_items_with_pagination(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("list_items", {"skip": 1, "limit": 1})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
results = json.loads(text_content.text)
|
||||
assert len(results) == 1
|
||||
|
||||
# Should be the second item in the list (after skipping the first)
|
||||
item = results[0]
|
||||
assert item["id"] == 2
|
||||
assert item["name"] == "Item 2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_item_not_found(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("get_item", {"item_id": 999})
|
||||
|
||||
assert response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert "404" in text_content.text
|
||||
assert "not found" in text_content.text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_update_item(lowlevel_server_simple_app: Server):
|
||||
test_update = {
|
||||
"item_id": 3,
|
||||
"id": 3,
|
||||
"name": "Updated Item 3",
|
||||
"description": "Updated description",
|
||||
"price": 35.99,
|
||||
"tags": ["updated", "modified"],
|
||||
}
|
||||
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("update_item", test_update)
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result = json.loads(text_content.text)
|
||||
|
||||
assert result["id"] == test_update["item_id"]
|
||||
assert result["name"] == test_update["name"]
|
||||
assert result["description"] == test_update["description"]
|
||||
assert result["price"] == test_update["price"]
|
||||
assert result["tags"] == test_update["tags"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_delete_item(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("delete_item", {"item_id": 3})
|
||||
|
||||
assert not response.isError
|
||||
# The endpoint returns 204 No Content, so we expect an empty response
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
assert (
|
||||
text_content.text.strip() == "{}" or text_content.text.strip() == "null" or text_content.text.strip() == ""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_item_with_details(lowlevel_server_simple_app: Server):
|
||||
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
|
||||
response = await client_session.call_tool("get_item", {"item_id": 1, "include_details": True})
|
||||
|
||||
assert not response.isError
|
||||
assert len(response.content) > 0
|
||||
|
||||
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
|
||||
result: dict = json.loads(text_content.text)
|
||||
parsed_result = Item(**result)
|
||||
|
||||
assert parsed_result.id == 1
|
||||
assert parsed_result.name == "Item 1"
|
||||
assert parsed_result.price == 10.0
|
||||
assert parsed_result.tags == ["tag1", "tag2"]
|
||||
assert parsed_result.description == "Item 1 description"
|
||||
Reference in New Issue
Block a user