mirror of
https://github.com/anthropics/claude-agent-sdk-python.git
synced 2025-10-06 01:00:03 +03:00
mypy
This commit is contained in:
@@ -57,7 +57,7 @@ class SdkMcpTool(Generic[T]):
|
||||
|
||||
def tool(
|
||||
name: str, description: str, input_schema: type | dict[str, Any]
|
||||
) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool]:
|
||||
) -> Callable[[Callable[[Any], Awaitable[dict[str, Any]]]], SdkMcpTool[Any]]:
|
||||
"""Decorator for defining MCP tools with type safety.
|
||||
|
||||
Creates a tool that can be used with SDK MCP servers. The tool runs
|
||||
@@ -105,7 +105,9 @@ def tool(
|
||||
- Errors can be indicated by including "is_error": True in the response
|
||||
"""
|
||||
|
||||
def decorator(handler: Callable[[Any], Awaitable[dict[str, Any]]]) -> SdkMcpTool:
|
||||
def decorator(
|
||||
handler: Callable[[Any], Awaitable[dict[str, Any]]],
|
||||
) -> SdkMcpTool[Any]:
|
||||
return SdkMcpTool(
|
||||
name=name,
|
||||
description=description,
|
||||
@@ -117,7 +119,7 @@ def tool(
|
||||
|
||||
|
||||
def create_sdk_mcp_server(
|
||||
name: str, version: str = "1.0.0", tools: list[SdkMcpTool] | None = None
|
||||
name: str, version: str = "1.0.0", tools: list[SdkMcpTool[Any]] | None = None
|
||||
) -> McpSdkServerConfig:
|
||||
"""Create an in-process MCP server that runs within your Python application.
|
||||
|
||||
@@ -200,7 +202,7 @@ def create_sdk_mcp_server(
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
# Register list_tools handler to expose available tools
|
||||
@server.list_tools()
|
||||
@server.list_tools() # type: ignore[no-untyped-call,misc]
|
||||
async def list_tools() -> list[Tool]:
|
||||
"""Return the list of available tools."""
|
||||
tool_list = []
|
||||
@@ -246,8 +248,8 @@ def create_sdk_mcp_server(
|
||||
return tool_list
|
||||
|
||||
# Register call_tool handler to execute tools
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> Any:
|
||||
@server.call_tool() # type: ignore[misc]
|
||||
async def call_tool(name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Execute a tool by name with given arguments."""
|
||||
if name not in tool_map:
|
||||
raise ValueError(f"Tool '{name}' not found")
|
||||
|
||||
@@ -20,10 +20,10 @@ class InternalClient:
|
||||
"""Initialize the internal client."""
|
||||
|
||||
def _convert_hooks_to_internal_format(
|
||||
self, hooks: dict[str, list]
|
||||
self, hooks: dict[str, list[Any]]
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Convert HookMatcher format to internal Query format."""
|
||||
internal_hooks = {}
|
||||
internal_hooks: dict[str, list[dict[str, Any]]] = {}
|
||||
for event, matchers in hooks.items():
|
||||
internal_hooks[event] = []
|
||||
for matcher in matchers:
|
||||
@@ -57,7 +57,7 @@ class InternalClient:
|
||||
if options.mcp_servers and isinstance(options.mcp_servers, dict):
|
||||
for name, config in options.mcp_servers.items():
|
||||
if isinstance(config, dict) and config.get("type") == "sdk":
|
||||
sdk_mcp_servers[name] = config["instance"]
|
||||
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
|
||||
|
||||
# Create Query to handle control protocol
|
||||
is_streaming = not isinstance(prompt, str)
|
||||
|
||||
@@ -47,7 +47,8 @@ class Query:
|
||||
transport: Transport,
|
||||
is_streaming_mode: bool,
|
||||
can_use_tool: Callable[
|
||||
[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
|
||||
[str, dict[str, Any], ToolPermissionContext],
|
||||
Awaitable[PermissionResultAllow | PermissionResultDeny],
|
||||
]
|
||||
| None = None,
|
||||
hooks: dict[str, list[dict[str, Any]]] | None = None,
|
||||
@@ -190,7 +191,7 @@ class Query:
|
||||
subtype = request_data["subtype"]
|
||||
|
||||
try:
|
||||
response_data = {}
|
||||
response_data: dict[str, Any] = {}
|
||||
|
||||
if subtype == "can_use_tool":
|
||||
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
|
||||
@@ -200,7 +201,8 @@ class Query:
|
||||
|
||||
context = ToolPermissionContext(
|
||||
signal=None, # TODO: Add abort signal support
|
||||
suggestions=permission_request.get("permission_suggestions", []),
|
||||
suggestions=permission_request.get("permission_suggestions", [])
|
||||
or [],
|
||||
)
|
||||
|
||||
response = await self.can_use_tool(
|
||||
@@ -237,7 +239,7 @@ class Query:
|
||||
{"signal": None}, # TODO: Add abort signal support
|
||||
)
|
||||
|
||||
elif subtype == "mcp_request":
|
||||
elif subtype == "mcp_message":
|
||||
# Handle SDK MCP request
|
||||
server_name = request_data.get("server_name")
|
||||
mcp_message = request_data.get("message")
|
||||
@@ -245,6 +247,9 @@ class Query:
|
||||
if not server_name or not mcp_message:
|
||||
raise Exception("Missing server_name or message for MCP request")
|
||||
|
||||
# Type narrowing - we've verified these are not None above
|
||||
assert isinstance(server_name, str)
|
||||
assert isinstance(mcp_message, dict)
|
||||
response_data = await self._handle_sdk_mcp_request(
|
||||
server_name, mcp_message
|
||||
)
|
||||
@@ -315,7 +320,9 @@ class Query:
|
||||
self.pending_control_results.pop(request_id, None)
|
||||
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
|
||||
|
||||
async def _handle_sdk_mcp_request(self, server_name: str, message: dict) -> dict:
|
||||
async def _handle_sdk_mcp_request(
|
||||
self, server_name: str, message: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Handle an MCP request for an SDK server.
|
||||
|
||||
This acts as a bridge between JSONRPC messages from the CLI
|
||||
@@ -360,11 +367,11 @@ class Query:
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema.model_dump()
|
||||
"inputSchema": tool.inputSchema.model_dump() # type: ignore[union-attr]
|
||||
if tool.inputSchema
|
||||
else {},
|
||||
}
|
||||
for tool in result.root.tools
|
||||
for tool in result.root.tools # type: ignore[union-attr]
|
||||
]
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
@@ -373,7 +380,7 @@ class Query:
|
||||
}
|
||||
|
||||
elif method == "tools/call":
|
||||
request = CallToolRequest(
|
||||
call_request = CallToolRequest(
|
||||
method=method,
|
||||
params=CallToolRequestParams(
|
||||
name=params.get("name"), arguments=params.get("arguments", {})
|
||||
@@ -381,10 +388,10 @@ class Query:
|
||||
)
|
||||
handler = server.request_handlers.get(CallToolRequest)
|
||||
if handler:
|
||||
result = await handler(request)
|
||||
result = await handler(call_request)
|
||||
# Convert MCP result to JSONRPC response
|
||||
content = []
|
||||
for item in result.root.content:
|
||||
for item in result.root.content: # type: ignore[union-attr]
|
||||
if hasattr(item, "text"):
|
||||
content.append({"type": "text", "text": item.text})
|
||||
elif hasattr(item, "data") and hasattr(item, "mimeType"):
|
||||
@@ -398,7 +405,7 @@ class Query:
|
||||
|
||||
response_data = {"content": content}
|
||||
if hasattr(result.root, "is_error") and result.root.is_error:
|
||||
response_data["is_error"] = True
|
||||
response_data["is_error"] = True # type: ignore[assignment]
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
|
||||
@@ -101,10 +101,10 @@ class ClaudeSDKClient:
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
|
||||
|
||||
def _convert_hooks_to_internal_format(
|
||||
self, hooks: dict[str, list]
|
||||
self, hooks: dict[str, list[Any]]
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Convert HookMatcher format to internal Query format."""
|
||||
internal_hooks = {}
|
||||
internal_hooks: dict[str, list[dict[str, Any]]] = {}
|
||||
for event, matchers in hooks.items():
|
||||
internal_hooks[event] = []
|
||||
for matcher in matchers:
|
||||
@@ -145,7 +145,7 @@ class ClaudeSDKClient:
|
||||
if self.options.mcp_servers and isinstance(self.options.mcp_servers, dict):
|
||||
for name, config in self.options.mcp_servers.items():
|
||||
if isinstance(config, dict) and config.get("type") == "sdk":
|
||||
sdk_mcp_servers[name] = config["instance"]
|
||||
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
|
||||
|
||||
# Create Query to handle control protocol
|
||||
self._query = Query(
|
||||
|
||||
@@ -5,10 +5,7 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
||||
|
||||
try:
|
||||
from typing import NotRequired # Python 3.11+
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired # For Python < 3.11 compatibility
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server import Server as McpServer
|
||||
|
||||
Reference in New Issue
Block a user