This commit is contained in:
Ashwin Bhat
2025-09-04 12:59:15 -07:00
parent 7b0938a1cc
commit d68adecd44
5 changed files with 33 additions and 27 deletions

View File

@@ -57,7 +57,7 @@ class SdkMcpTool(Generic[T]):
def tool(
name: str, description: str, input_schema: type | dict[str, Any]
) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool]:
) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool[Any]]:
"""Decorator for defining MCP tools with type safety.
Creates a tool that can be used with SDK MCP servers. The tool runs
@@ -105,7 +105,9 @@ def tool(
- Errors can be indicated by including "is_error": True in the response
"""
def decorator(handler: Callable[[Any], Awaitable[dict[str, Any]]]) -> SdkMcpTool:
def decorator(
handler: Callable[[Any], Awaitable[dict[str, Any]]],
) -> SdkMcpTool[Any]:
return SdkMcpTool(
name=name,
description=description,
@@ -117,7 +119,7 @@ def tool(
def create_sdk_mcp_server(
name: str, version: str = "1.0.0", tools: list[SdkMcpTool] | None = None
name: str, version: str = "1.0.0", tools: list[SdkMcpTool[Any]] | None = None
) -> McpSdkServerConfig:
"""Create an in-process MCP server that runs within your Python application.
@@ -200,7 +202,7 @@ def create_sdk_mcp_server(
tool_map = {tool_def.name: tool_def for tool_def in tools}
# Register list_tools handler to expose available tools
@server.list_tools()
@server.list_tools() # type: ignore[no-untyped-call,misc]
async def list_tools() -> list[Tool]:
"""Return the list of available tools."""
tool_list = []
@@ -246,8 +248,8 @@ def create_sdk_mcp_server(
return tool_list
# Register call_tool handler to execute tools
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> Any:
@server.call_tool() # type: ignore[misc]
async def call_tool(name: str, arguments: dict[str, Any]) -> Any:
"""Execute a tool by name with given arguments."""
if name not in tool_map:
raise ValueError(f"Tool '{name}' not found")

View File

@@ -20,10 +20,10 @@ class InternalClient:
"""Initialize the internal client."""
def _convert_hooks_to_internal_format(
self, hooks: dict[str, list]
self, hooks: dict[str, list[Any]]
) -> dict[str, list[dict[str, Any]]]:
"""Convert HookMatcher format to internal Query format."""
internal_hooks = {}
internal_hooks: dict[str, list[dict[str, Any]]] = {}
for event, matchers in hooks.items():
internal_hooks[event] = []
for matcher in matchers:
@@ -57,7 +57,7 @@ class InternalClient:
if options.mcp_servers and isinstance(options.mcp_servers, dict):
for name, config in options.mcp_servers.items():
if isinstance(config, dict) and config.get("type") == "sdk":
sdk_mcp_servers[name] = config["instance"]
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
# Create Query to handle control protocol
is_streaming = not isinstance(prompt, str)

View File

@@ -47,7 +47,8 @@ class Query:
transport: Transport,
is_streaming_mode: bool,
can_use_tool: Callable[
[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
[str, dict[str, Any], ToolPermissionContext],
Awaitable[PermissionResultAllow | PermissionResultDeny],
]
| None = None,
hooks: dict[str, list[dict[str, Any]]] | None = None,
@@ -190,7 +191,7 @@ class Query:
subtype = request_data["subtype"]
try:
response_data = {}
response_data: dict[str, Any] = {}
if subtype == "can_use_tool":
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
@@ -200,7 +201,8 @@ class Query:
context = ToolPermissionContext(
signal=None, # TODO: Add abort signal support
suggestions=permission_request.get("permission_suggestions", []),
suggestions=permission_request.get("permission_suggestions", [])
or [],
)
response = await self.can_use_tool(
@@ -237,7 +239,7 @@ class Query:
{"signal": None}, # TODO: Add abort signal support
)
elif subtype == "mcp_request":
elif subtype == "mcp_message":
# Handle SDK MCP request
server_name = request_data.get("server_name")
mcp_message = request_data.get("message")
@@ -245,6 +247,9 @@ class Query:
if not server_name or not mcp_message:
raise Exception("Missing server_name or message for MCP request")
# Type narrowing - we've verified these are not None above
assert isinstance(server_name, str)
assert isinstance(mcp_message, dict)
response_data = await self._handle_sdk_mcp_request(
server_name, mcp_message
)
@@ -315,7 +320,9 @@ class Query:
self.pending_control_results.pop(request_id, None)
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict:
async def _handle_sdk_mcp_request(
self, server_name: str, message: dict[str, Any]
) -> dict[str, Any]:
"""Handle an MCP request for an SDK server.
This acts as a bridge between JSONRPC messages from the CLI
@@ -360,11 +367,11 @@ class Query:
{
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema.model_dump()
"inputSchema": tool.inputSchema.model_dump() # type: ignore[union-attr]
if tool.inputSchema
else {},
}
for tool in result.root.tools
for tool in result.root.tools # type: ignore[union-attr]
]
return {
"jsonrpc": "2.0",
@@ -373,7 +380,7 @@ class Query:
}
elif method == "tools/call":
request = CallToolRequest(
call_request = CallToolRequest(
method=method,
params=CallToolRequestParams(
name=params.get("name"), arguments=params.get("arguments", {})
@@ -381,10 +388,10 @@ class Query:
)
handler = server.request_handlers.get(CallToolRequest)
if handler:
result = await handler(request)
result = await handler(call_request)
# Convert MCP result to JSONRPC response
content = []
for item in result.root.content:
for item in result.root.content: # type: ignore[union-attr]
if hasattr(item, "text"):
content.append({"type": "text", "text": item.text})
elif hasattr(item, "data") and hasattr(item, "mimeType"):
@@ -398,7 +405,7 @@ class Query:
response_data = {"content": content}
if hasattr(result.root, "is_error") and result.root.is_error:
response_data["is_error"] = True
response_data["is_error"] = True # type: ignore[assignment]
return {
"jsonrpc": "2.0",

View File

@@ -101,10 +101,10 @@ class ClaudeSDKClient:
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
def _convert_hooks_to_internal_format(
self, hooks: dict[str, list]
self, hooks: dict[str, list[Any]]
) -> dict[str, list[dict[str, Any]]]:
"""Convert HookMatcher format to internal Query format."""
internal_hooks = {}
internal_hooks: dict[str, list[dict[str, Any]]] = {}
for event, matchers in hooks.items():
internal_hooks[event] = []
for matcher in matchers:
@@ -145,7 +145,7 @@ class ClaudeSDKClient:
if self.options.mcp_servers and isinstance(self.options.mcp_servers, dict):
for name, config in self.options.mcp_servers.items():
if isinstance(config, dict) and config.get("type") == "sdk":
sdk_mcp_servers[name] = config["instance"]
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
# Create Query to handle control protocol
self._query = Query(

View File

@@ -5,10 +5,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, TypedDict
try:
from typing import NotRequired # Python 3.11+
except ImportError:
from typing_extensions import NotRequired # For Python < 3.11 compatibility
from typing_extensions import NotRequired
if TYPE_CHECKING:
from mcp.server import Server as McpServer