enhance test suite and fix errors

This commit is contained in:
Shahar Abramov
2025-04-09 17:54:08 +03:00
parent 9bc99f5a20
commit e0d405db82
8 changed files with 660 additions and 124 deletions

View File

@@ -1,7 +1,6 @@
import json
import httpx
from contextlib import asynccontextmanager
from typing import Dict, Optional, Any, List, Union, AsyncIterator
from typing import Dict, Optional, Any, List, Union
from fastapi import FastAPI, Request, APIRouter
from fastapi.openapi.utils import get_openapi
@@ -10,6 +9,7 @@ import mcp.types as types
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
from fastapi_mcp.transport.sse import FastApiSseTransport
from fastapi_mcp.types import AsyncClientProtocol
from logging import getLogger
@@ -26,7 +26,24 @@ class FastApiMCP:
base_url: Optional[str] = None,
describe_all_responses: bool = False,
describe_full_response_schema: bool = False,
http_client: Optional[AsyncClientProtocol] = None,
):
"""
Create an MCP server from a FastAPI app.
Args:
fastapi: The FastAPI application
name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description)
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
as the root path would be different when the app is deployed.
describe_all_responses: Whether to include all possible response schemas in tool descriptions
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
http_client: Optional HTTP client to use for API calls. If not provided, a new httpx.AsyncClient will be created.
This is primarily for testing purposes.
"""
self.operation_map: Dict[str, Dict[str, Any]]
self.tools: List[types.Tool]
@@ -38,27 +55,11 @@ class FastApiMCP:
self._describe_all_responses = describe_all_responses
self._describe_full_response_schema = describe_full_response_schema
self._http_client = http_client or httpx.AsyncClient()
self.server = self.create_server()
def create_server(self) -> Server:
"""
Create an MCP server from the FastAPI app.
Args:
fastapi: The FastAPI application
name: Name for the MCP server (defaults to app.title)
description: Description for the MCP server (defaults to app.description)
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
as the root path would be different when the app is deployed.
describe_all_responses: Whether to include all possible response schemas in tool descriptions
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
Returns:
A tuple containing:
- The created MCP Server instance (NOT mounted to the app)
- A mapping of operation IDs to operation details for HTTP execution
"""
# Get OpenAPI schema from FastAPI app
openapi_schema = get_openapi(
title=self.fastapi.title,
@@ -93,23 +94,12 @@ class FastApiMCP:
if self._base_url.endswith("/"):
self._base_url = self._base_url[:-1]
# Create the MCP server
# Create the MCP lowlevel server
mcp_server: Server = Server(self.name, self.description)
# Create a lifespan context manager to store the base_url and operation_map
@asynccontextmanager
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
# Store context data that will be available to all server handlers
context = {"base_url": self._base_url, "operation_map": self.operation_map}
yield context
# Use our custom lifespan
mcp_server.lifespan = server_lifespan
# Register handlers for tools
@mcp_server.list_tools()
async def handle_list_tools() -> List[types.Tool]:
"""Handler for the tools/list request"""
return self.tools
# Register the tool call handler
@@ -117,14 +107,13 @@ class FastApiMCP:
async def handle_call_tool(
name: str, arguments: Dict[str, Any]
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
"""Handler for the tools/call request"""
# Get context from server lifespan
ctx = mcp_server.request_context
base_url = ctx.lifespan_context["base_url"]
operation_map = ctx.lifespan_context["operation_map"]
# Execute the tool
return await self.execute_api_tool(base_url, name, arguments, operation_map)
return await self._execute_api_tool(
client=self._http_client,
base_url=self._base_url or "",
tool_name=name,
arguments=arguments,
operation_map=self.operation_map,
)
return mcp_server
@@ -168,8 +157,13 @@ class FastApiMCP:
logger.info(f"MCP server listening at {mount_path}")
async def execute_api_tool(
self, base_url: str, tool_name: str, arguments: Dict[str, Any], operation_map: Dict[str, Dict[str, Any]]
async def _execute_api_tool(
self,
client: AsyncClientProtocol,
base_url: str,
tool_name: str,
arguments: Dict[str, Any],
operation_map: Dict[str, Dict[str, Any]],
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
"""
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
@@ -179,12 +173,13 @@ class FastApiMCP:
tool_name: The name of the tool to execute
arguments: The arguments for the tool
operation_map: A mapping from tool names to operation details
client: Optional HTTP client to use (primarily for testing)
Returns:
The result as MCP content types
"""
if tool_name not in operation_map:
return [types.TextContent(type="text", text=f"Unknown tool: {tool_name}")]
raise Exception(f"Unknown tool: {tool_name}")
operation = operation_map[tool_name]
path: str = operation["path"]
@@ -192,7 +187,6 @@ class FastApiMCP:
parameters: List[Dict[str, Any]] = operation.get("parameters", [])
arguments = arguments.copy() if arguments else {} # Deep copy arguments to avoid mutating the original
# Prepare URL with path parameters
url = f"{base_url}{path}"
for param in parameters:
if param.get("in") == "path" and param.get("name") in arguments:
@@ -201,7 +195,6 @@ class FastApiMCP:
raise ValueError(f"Parameter name is None for parameter: {param}")
url = url.replace(f"{{{param_name}}}", str(arguments.pop(param_name)))
# Prepare query parameters
query = {}
for param in parameters:
if param.get("in") == "query" and param.get("name") in arguments:
@@ -210,7 +203,6 @@ class FastApiMCP:
raise ValueError(f"Parameter name is None for parameter: {param}")
query[param_name] = arguments.pop(param_name)
# Prepare headers
headers = {}
for param in parameters:
if param.get("in") == "header" and param.get("name") in arguments:
@@ -219,32 +211,57 @@ class FastApiMCP:
raise ValueError(f"Parameter name is None for parameter: {param}")
headers[param_name] = arguments.pop(param_name)
# Prepare request body (remaining kwargs)
body = arguments if arguments else None
try:
# Make request
logger.debug(f"Making {method.upper()} request to {url}")
async with httpx.AsyncClient() as client:
if method.lower() == "get":
response = await client.get(url, params=query, headers=headers)
elif method.lower() == "post":
response = await client.post(url, params=query, headers=headers, json=body)
elif method.lower() == "put":
response = await client.put(url, params=query, headers=headers, json=body)
elif method.lower() == "delete":
response = await client.delete(url, params=query, headers=headers)
elif method.lower() == "patch":
response = await client.patch(url, params=query, headers=headers, json=body)
else:
return [types.TextContent(type="text", text=f"Unsupported HTTP method: {method}")]
response = await self._request(client, method, url, query, headers, body)
# Process response
# TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc.
try:
result = response.json()
return [types.TextContent(type="text", text=json.dumps(result, indent=2))]
result_text = json.dumps(result, indent=2)
except json.JSONDecodeError:
if hasattr(response, "text"):
result_text = response.text
else:
result_text = response.content
# If not raising an exception, the MCP server will return the result as a regular text response, without marking it as an error.
# TODO: Use a raise_for_status() method on the response (it needs to also be implemented in the AsyncClientProtocol)
if 400 <= response.status_code < 600:
raise Exception(
f"Error calling {tool_name}. Status code: {response.status_code}. Response: {response.text}"
)
try:
return [types.TextContent(type="text", text=result_text)]
except ValueError:
return [types.TextContent(type="text", text=response.text)]
return [types.TextContent(type="text", text=result_text)]
except Exception as e:
return [types.TextContent(type="text", text=f"Error calling {tool_name}: {str(e)}")]
logger.exception(f"Error calling {tool_name}")
raise e
async def _request(
self,
client: AsyncClientProtocol,
method: str,
url: str,
query: Dict[str, Any],
headers: Dict[str, str],
body: Optional[Any],
) -> Any:
"""Helper method to make the actual HTTP request"""
if method.lower() == "get":
return await client.get(url, params=query, headers=headers)
elif method.lower() == "post":
return await client.post(url, params=query, headers=headers, json=body)
elif method.lower() == "put":
return await client.put(url, params=query, headers=headers, json=body)
elif method.lower() == "delete":
return await client.delete(url, params=query, headers=headers)
elif method.lower() == "patch":
return await client.patch(url, params=query, headers=headers, json=body)
else:
raise ValueError(f"Unsupported HTTP method: {method}")

46
fastapi_mcp/types.py Normal file
View File

@@ -0,0 +1,46 @@
from pydantic import BaseModel, ConfigDict
from typing import Any, Protocol, Optional, Dict
class BaseType(BaseModel):
model_config = ConfigDict(extra="forbid")
class AsyncClientProtocol(Protocol):
"""Protocol defining the interface for async HTTP clients."""
async def get(
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
) -> Any: ...
async def post(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any: ...
async def put(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any: ...
async def delete(
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
) -> Any: ...
async def patch(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any: ...

View File

View File

@@ -0,0 +1,63 @@
import json
from typing import Any, Dict, Optional
from fastapi import FastAPI
from fastapi.testclient import TestClient
from fastapi_mcp.server import AsyncClientProtocol
class FastAPITestClient(AsyncClientProtocol):
def __init__(self, app: FastAPI):
self.client = TestClient(app, raise_server_exceptions=False)
async def get(
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
) -> Any:
response = self.client.get(url, params=params, headers=headers)
return self._wrap_response(response)
async def post(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any:
response = self.client.post(url, params=params, headers=headers, json=json)
return self._wrap_response(response)
async def put(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any:
response = self.client.put(url, params=params, headers=headers, json=json)
return self._wrap_response(response)
async def delete(
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
) -> Any:
response = self.client.delete(url, params=params, headers=headers)
return self._wrap_response(response)
async def patch(
self,
url: str,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
json: Optional[Any] = None,
) -> Any:
response = self.client.patch(url, params=params, headers=headers, json=json)
return self._wrap_response(response)
def _wrap_response(self, response: Any) -> Any:
response.json = (
lambda: json.loads(response.content) if hasattr(response, "content") and response.content else None
)
return response

View File

@@ -8,55 +8,3 @@ from .fixtures.types import * # noqa: F403
from .fixtures.example_data import * # noqa: F403
from .fixtures.simple_app import * # noqa: F403
from .fixtures.complex_app import * # noqa: F403
# Add specific fixtures for MCP testing
import pytest
from fastapi.testclient import TestClient
from fastapi_mcp import FastApiMCP
@pytest.fixture
def mcp_server(simple_fastapi_app):
"""
Create a basic MCP server instance for the simple_fastapi_app.
This is a utility fixture to be used by multiple tests.
"""
return FastApiMCP(
simple_fastapi_app,
name="Test MCP Server",
description="Test MCP server for unit testing",
base_url="http://testserver",
)
@pytest.fixture
def complex_mcp_server(complex_fastapi_app):
"""
Create a MCP server instance for the complex_fastapi_app.
This is a utility fixture to be used by multiple tests.
"""
return FastApiMCP(
complex_fastapi_app,
name="Complex Test MCP Server",
description="Complex test MCP server for unit testing",
base_url="http://testserver",
describe_all_responses=True,
describe_full_response_schema=True,
)
@pytest.fixture
def client(simple_fastapi_app):
"""
Create a test client for the simple_fastapi_app.
"""
return TestClient(simple_fastapi_app)
@pytest.fixture
def complex_client(complex_fastapi_app):
"""
Create a test client for the complex_fastapi_app.
"""
return TestClient(complex_fastapi_app)

View File

@@ -1,6 +1,6 @@
from typing import Optional, List
from fastapi import FastAPI, Query, Path, Body
from fastapi import FastAPI, Query, Path, Body, HTTPException
import pytest
from .types import Item
@@ -14,6 +14,12 @@ def simple_fastapi_app() -> FastAPI:
version="0.1.0",
)
items = [
Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"),
Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]),
Item(id=3, name="Item 3", price=30.0, tags=["tag3", "tag4"], description="Item 3 description"),
]
@app.get("/items/", response_model=List[Item], tags=["items"], operation_id="list_items")
async def list_items(
skip: int = Query(0, description="Number of items to skip"),
@@ -21,11 +27,7 @@ def simple_fastapi_app() -> FastAPI:
sort_by: Optional[str] = Query(None, description="Field to sort by"),
):
"""List all items with pagination and sorting options."""
return [
Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"),
Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]),
Item(id=3, name="Item 3", price=30.0, tags=["tag3", "tag4"], description="Item 3 description"),
]
return items[skip : skip + limit]
@app.get("/items/{item_id}", response_model=Item, tags=["items"], operation_id="get_item")
async def read_item(
@@ -33,11 +35,15 @@ def simple_fastapi_app() -> FastAPI:
include_details: bool = Query(False, description="Include additional details"),
):
"""Get a specific item by its ID with optional details."""
return Item(id=item_id, name="Test Item", price=10.0, tags=["tag1", "tag2"])
found_item = next((item for item in items if item.id == item_id), None)
if found_item is None:
raise HTTPException(status_code=404, detail="Item not found")
return found_item
@app.post("/items/", response_model=Item, tags=["items"], operation_id="create_item")
async def create_item(item: Item = Body(..., description="The item to create")):
"""Create a new item in the database."""
items.append(item)
return item
@app.put("/items/{item_id}", response_model=Item, tags=["items"], operation_id="update_item")
@@ -54,4 +60,9 @@ def simple_fastapi_app() -> FastAPI:
"""Delete an item from the database."""
return None
@app.get("/error", status_code=200, tags=["error"], operation_id="raise_error")
async def raise_error():
"""Fail on purpose and cause a 500 error."""
raise Exception("This is a test error")
return app

View File

@@ -0,0 +1,218 @@
import json
import pytest
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.shared.memory import create_connected_server_and_client_session
from fastapi import FastAPI
from fastapi_mcp import FastApiMCP
from fastapi_mcp.utils.testing import FastAPITestClient
from .fixtures.types import Product, Customer, OrderResponse
@pytest.fixture
def lowlevel_server_complex_app(complex_fastapi_app: FastAPI) -> Server:
mcp = FastApiMCP(
complex_fastapi_app,
name="Test MCP Server",
description="Test description",
base_url="",
http_client=FastAPITestClient(complex_fastapi_app),
)
return mcp.server
@pytest.mark.asyncio
async def test_list_tools(lowlevel_server_complex_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
tools_result = await client_session.list_tools()
assert len(tools_result.tools) > 0
tool_names = [tool.name for tool in tools_result.tools]
expected_operations = ["list_products", "get_product", "create_order", "get_customer"]
for op in expected_operations:
assert op in tool_names
@pytest.mark.asyncio
async def test_call_tool_list_products_default(lowlevel_server_complex_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool("list_products", {})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert "items" in result
assert result["total"] == 1
assert result["page"] == 1
assert len(result["items"]) == 1
@pytest.mark.asyncio
async def test_call_tool_list_products_with_filters(lowlevel_server_complex_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool(
"list_products",
{"category": "electronics", "min_price": 10.0, "page": 1, "size": 10, "in_stock_only": True},
)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert "items" in result
assert result["page"] == 1
assert result["size"] == 10
@pytest.mark.asyncio
async def test_call_tool_get_product(lowlevel_server_complex_app: Server, example_product: Product):
product_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool("get_product", {"product_id": product_id})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == product_id
assert "name" in result
assert "price" in result
assert "description" in result
@pytest.mark.asyncio
async def test_call_tool_get_product_with_options(lowlevel_server_complex_app: Server):
product_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool(
"get_product", {"product_id": product_id, "include_unavailable": True}
)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == product_id
@pytest.mark.asyncio
async def test_call_tool_create_order(lowlevel_server_complex_app: Server, example_order_response: OrderResponse):
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
product_id = "123e4567-e89b-12d3-a456-426614174001" # Valid UUID format
shipping_address_id = "123e4567-e89b-12d3-a456-426614174002" # Valid UUID format
order_request = {
"customer_id": customer_id,
"items": [{"product_id": product_id, "quantity": 2, "unit_price": 29.99, "total": 59.98}],
"shipping_address_id": shipping_address_id,
"payment_method": "credit_card",
}
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool("create_order", order_request)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["customer_id"] == customer_id
assert "id" in result
assert "status" in result
assert "items" in result
assert len(result["items"]) > 0
@pytest.mark.asyncio
async def test_call_tool_create_order_validation_error(lowlevel_server_complex_app: Server):
# Missing required fields
order_request = {
# Missing customer_id
"items": [],
# Missing shipping_address_id
"payment_method": "credit_card",
}
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool("create_order", order_request)
assert response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert "422" in text_content.text or "validation" in text_content.text.lower()
@pytest.mark.asyncio
async def test_call_tool_get_customer(lowlevel_server_complex_app: Server, example_customer: Customer):
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool("get_customer", {"customer_id": customer_id})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == customer_id
assert "full_name" in result
assert "email" in result
@pytest.mark.asyncio
async def test_call_tool_get_customer_with_options(lowlevel_server_complex_app: Server):
customer_id = "123e4567-e89b-12d3-a456-426614174000" # Valid UUID format
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
response = await client_session.call_tool(
"get_customer",
{
"customer_id": customer_id,
"include_orders": True,
"include_payment_methods": True,
"fields": ["full_name", "email", "orders"],
},
)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == customer_id
@pytest.mark.asyncio
async def test_error_handling_missing_parameter(lowlevel_server_complex_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_complex_app) as client_session:
# Missing required product_id parameter
response = await client_session.call_tool("get_product", {})
assert response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert (
"422" in text_content.text
or "parameter" in text_content.text.lower()
or "field" in text_content.text.lower()
)

View File

@@ -0,0 +1,233 @@
import json
import pytest
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.shared.memory import create_connected_server_and_client_session
from fastapi import FastAPI
from fastapi_mcp import FastApiMCP
from fastapi_mcp.utils.testing import FastAPITestClient
from .fixtures.types import Item
@pytest.fixture
def lowlevel_server_simple_app(simple_fastapi_app: FastAPI) -> Server:
mcp = FastApiMCP(
simple_fastapi_app,
name="Test MCP Server",
description="Test description",
base_url="",
http_client=FastAPITestClient(simple_fastapi_app),
)
return mcp.server
@pytest.mark.asyncio
async def test_list_tools(lowlevel_server_simple_app: Server):
"""Test listing tools via direct MCP connection."""
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
tools_result = await client_session.list_tools()
assert len(tools_result.tools) > 0
tool_names = [tool.name for tool in tools_result.tools]
expected_operations = ["list_items", "get_item", "create_item", "update_item", "delete_item", "raise_error"]
for op in expected_operations:
assert op in tool_names
@pytest.mark.asyncio
async def test_call_tool_get_item_1(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("get_item", {"item_id": 1})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result: dict = json.loads(text_content.text)
parsed_result = Item(**result)
assert parsed_result.id == 1
assert parsed_result.name == "Item 1"
assert parsed_result.price == 10.0
assert parsed_result.tags == ["tag1", "tag2"]
@pytest.mark.asyncio
async def test_call_tool_get_item_2(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("get_item", {"item_id": 2})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result: dict = json.loads(text_content.text)
parsed_result = Item(**result)
assert parsed_result.id == 2
assert parsed_result.name == "Item 2"
assert parsed_result.price == 20.0
assert parsed_result.tags == ["tag2", "tag3"]
@pytest.mark.asyncio
async def test_call_tool_raise_error(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("raise_error", {})
assert response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert "500" in text_content.text
assert "internal server error" in text_content.text.lower()
@pytest.mark.asyncio
async def test_error_handling(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("get_item", {})
assert response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert "item_id" in text_content.text.lower() or "missing" in text_content.text.lower()
assert "422" in text_content.text, "Expected a 422 status to appear in the response text"
@pytest.mark.asyncio
async def test_complex_tool_arguments(lowlevel_server_simple_app: Server):
test_item = {
"id": 42,
"name": "Test Item",
"description": "A test item for MCP",
"price": 9.99,
"tags": ["test", "mcp"],
}
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("create_item", test_item)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == test_item["id"]
assert result["name"] == test_item["name"]
assert result["price"] == test_item["price"]
assert result["tags"] == test_item["tags"]
@pytest.mark.asyncio
async def test_call_tool_list_items_default(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("list_items", {})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
results = json.loads(text_content.text)
assert len(results) == 3 # Default should return all three items with default pagination
# Check first item matches expected data
item = results[0]
assert item["id"] == 1
assert item["name"] == "Item 1"
@pytest.mark.asyncio
async def test_call_tool_list_items_with_pagination(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("list_items", {"skip": 1, "limit": 1})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
results = json.loads(text_content.text)
assert len(results) == 1
# Should be the second item in the list (after skipping the first)
item = results[0]
assert item["id"] == 2
assert item["name"] == "Item 2"
@pytest.mark.asyncio
async def test_call_tool_get_item_not_found(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("get_item", {"item_id": 999})
assert response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert "404" in text_content.text
assert "not found" in text_content.text.lower()
@pytest.mark.asyncio
async def test_call_tool_update_item(lowlevel_server_simple_app: Server):
test_update = {
"item_id": 3,
"id": 3,
"name": "Updated Item 3",
"description": "Updated description",
"price": 35.99,
"tags": ["updated", "modified"],
}
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("update_item", test_update)
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result = json.loads(text_content.text)
assert result["id"] == test_update["item_id"]
assert result["name"] == test_update["name"]
assert result["description"] == test_update["description"]
assert result["price"] == test_update["price"]
assert result["tags"] == test_update["tags"]
@pytest.mark.asyncio
async def test_call_tool_delete_item(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("delete_item", {"item_id": 3})
assert not response.isError
# The endpoint returns 204 No Content, so we expect an empty response
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
assert (
text_content.text.strip() == "{}" or text_content.text.strip() == "null" or text_content.text.strip() == ""
)
@pytest.mark.asyncio
async def test_call_tool_get_item_with_details(lowlevel_server_simple_app: Server):
async with create_connected_server_and_client_session(lowlevel_server_simple_app) as client_session:
response = await client_session.call_tool("get_item", {"item_id": 1, "include_details": True})
assert not response.isError
assert len(response.content) > 0
text_content = next(c for c in response.content if isinstance(c, types.TextContent))
result: dict = json.loads(text_content.text)
parsed_result = Item(**result)
assert parsed_result.id == 1
assert parsed_result.name == "Item 1"
assert parsed_result.price == 10.0
assert parsed_result.tags == ["tag1", "tag2"]
assert parsed_result.description == "Item 1 description"