mirror of
https://github.com/anthropics/claude-agent-sdk-python.git
synced 2025-10-06 01:00:03 +03:00
Use anyio, not asyncio, in src
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""Internal client implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
@@ -55,9 +54,10 @@ class InternalClient:
|
||||
await query.initialize()
|
||||
|
||||
# Stream input if it's an AsyncIterable
|
||||
if isinstance(prompt, AsyncIterable):
|
||||
if isinstance(prompt, AsyncIterable) and query._tg:
|
||||
# Start streaming in background
|
||||
asyncio.create_task(query.stream_input(prompt))
|
||||
# Create a task that will run in the background
|
||||
query._tg.start_soon(query.stream_input, prompt)
|
||||
# For string prompts, the prompt is already passed via CLI args
|
||||
|
||||
# Yield parsed messages
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Query class for handling bidirectional control protocol."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -8,6 +7,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from ..types import (
|
||||
SDKControlPermissionRequest,
|
||||
SDKControlRequest,
|
||||
@@ -54,14 +55,17 @@ class Query:
|
||||
self.hooks = hooks or {}
|
||||
|
||||
# Control protocol state
|
||||
self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self.pending_control_responses: dict[str, anyio.Event] = {}
|
||||
self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
|
||||
self.hook_callbacks: dict[str, Callable[..., Any]] = {}
|
||||
self.next_callback_id = 0
|
||||
self._request_counter = 0
|
||||
|
||||
# Message stream
|
||||
self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._read_task: asyncio.Task[None] | None = None
|
||||
self._message_send, self._message_receive = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
](max_buffer_size=100)
|
||||
self._tg: anyio.abc.TaskGroup | None = None
|
||||
self._initialized = False
|
||||
self._closed = False
|
||||
self._initialization_result: dict[str, Any] | None = None
|
||||
@@ -108,8 +112,10 @@ class Query:
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start reading messages from transport."""
|
||||
if self._read_task is None:
|
||||
self._read_task = asyncio.create_task(self._read_messages())
|
||||
if self._tg is None:
|
||||
self._tg = anyio.create_task_group()
|
||||
await self._tg.__aenter__()
|
||||
self._tg.start_soon(self._read_messages)
|
||||
|
||||
async def _read_messages(self) -> None:
|
||||
"""Read messages from transport and route them."""
|
||||
@@ -125,20 +131,22 @@ class Query:
|
||||
response = message.get("response", {})
|
||||
request_id = response.get("request_id")
|
||||
if request_id in self.pending_control_responses:
|
||||
future = self.pending_control_responses.pop(request_id)
|
||||
event = self.pending_control_responses[request_id]
|
||||
if response.get("subtype") == "error":
|
||||
future.set_exception(
|
||||
Exception(response.get("error", "Unknown error"))
|
||||
self.pending_control_results[request_id] = Exception(
|
||||
response.get("error", "Unknown error")
|
||||
)
|
||||
else:
|
||||
future.set_result(response)
|
||||
self.pending_control_results[request_id] = response
|
||||
event.set()
|
||||
continue
|
||||
|
||||
elif msg_type == "control_request":
|
||||
# Handle incoming control requests from CLI
|
||||
# Cast message to SDKControlRequest for type safety
|
||||
request: SDKControlRequest = message # type: ignore[assignment]
|
||||
asyncio.create_task(self._handle_control_request(request))
|
||||
if self._tg:
|
||||
self._tg.start_soon(self._handle_control_request, request)
|
||||
continue
|
||||
|
||||
elif msg_type == "control_cancel_request":
|
||||
@@ -146,20 +154,20 @@ class Query:
|
||||
# TODO: Implement cancellation support
|
||||
continue
|
||||
|
||||
# Regular SDK messages go to the queue
|
||||
await self._message_queue.put(message)
|
||||
# Regular SDK messages go to the stream
|
||||
await self._message_send.send(message)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# Task was cancelled - this is expected behavior
|
||||
logger.debug("Read task cancelled")
|
||||
raise # Re-raise to properly handle cancellation
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in message reader: {e}")
|
||||
# Put error in queue so iterators can handle it
|
||||
await self._message_queue.put({"type": "error", "error": str(e)})
|
||||
# Put error in stream so iterators can handle it
|
||||
await self._message_send.send({"type": "error", "error": str(e)})
|
||||
finally:
|
||||
# Always signal end of stream
|
||||
await self._message_queue.put({"type": "end"})
|
||||
await self._message_send.send({"type": "end"})
|
||||
|
||||
async def _handle_control_request(self, request: SDKControlRequest) -> None:
|
||||
"""Handle incoming control request from CLI."""
|
||||
@@ -234,9 +242,9 @@ class Query:
|
||||
self._request_counter += 1
|
||||
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
|
||||
|
||||
# Create future for response
|
||||
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
|
||||
self.pending_control_responses[request_id] = future
|
||||
# Create event for response
|
||||
event = anyio.Event()
|
||||
self.pending_control_responses[request_id] = event
|
||||
|
||||
# Build and send request
|
||||
control_request = {
|
||||
@@ -249,11 +257,20 @@ class Query:
|
||||
|
||||
# Wait for response
|
||||
try:
|
||||
response = await asyncio.wait_for(future, timeout=60.0)
|
||||
result = response.get("response", {})
|
||||
return result if isinstance(result, dict) else {}
|
||||
except asyncio.TimeoutError as e:
|
||||
with anyio.fail_after(60.0):
|
||||
await event.wait()
|
||||
|
||||
result = self.pending_control_results.pop(request_id)
|
||||
self.pending_control_responses.pop(request_id, None)
|
||||
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
|
||||
response_data = result.get("response", {})
|
||||
return response_data if isinstance(response_data, dict) else {}
|
||||
except TimeoutError as e:
|
||||
self.pending_control_responses.pop(request_id, None)
|
||||
self.pending_control_results.pop(request_id, None)
|
||||
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
@@ -283,25 +300,24 @@ class Query:
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Receive SDK messages (not control messages)."""
|
||||
while not self._closed:
|
||||
message = await self._message_queue.get()
|
||||
async with self._message_receive:
|
||||
async for message in self._message_receive:
|
||||
# Check for special messages
|
||||
if message.get("type") == "end":
|
||||
break
|
||||
elif message.get("type") == "error":
|
||||
raise Exception(message.get("error", "Unknown error"))
|
||||
|
||||
# Check for special messages
|
||||
if message.get("type") == "end":
|
||||
break
|
||||
elif message.get("type") == "error":
|
||||
raise Exception(message.get("error", "Unknown error"))
|
||||
|
||||
yield message
|
||||
yield message
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the query and transport."""
|
||||
self._closed = True
|
||||
if self._read_task and not self._read_task.done():
|
||||
self._read_task.cancel()
|
||||
# Wait for task to complete cancellation
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self._read_task
|
||||
if self._tg:
|
||||
self._tg.cancel_scope.cancel()
|
||||
# Wait for task group to complete cancellation
|
||||
with suppress(anyio.get_cancelled_exc_class()):
|
||||
await self._tg.__aexit__(None, None, None)
|
||||
await self.transport.close()
|
||||
|
||||
# Make Query an async iterator
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Claude SDK Client for interacting with Claude Code."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
@@ -138,8 +137,8 @@ class ClaudeSDKClient:
|
||||
await self._query.initialize()
|
||||
|
||||
# If we have an initial prompt stream, start streaming it
|
||||
if prompt is not None and isinstance(prompt, AsyncIterable):
|
||||
asyncio.create_task(self._query.stream_input(prompt))
|
||||
if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg:
|
||||
self._query._tg.start_soon(self._query.stream_input, prompt)
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[Message]:
|
||||
"""Receive all messages from Claude."""
|
||||
|
||||
@@ -194,18 +194,19 @@ class SDKControlRequest(TypedDict):
|
||||
| SDKControlMcpMessageRequest
|
||||
)
|
||||
|
||||
|
||||
class ControlResponse(TypedDict):
|
||||
subtype: Literal['success']
|
||||
request_id: str
|
||||
response: dict[str, Any] | None
|
||||
subtype: Literal["success"]
|
||||
request_id: str
|
||||
response: dict[str, Any] | None
|
||||
|
||||
|
||||
class ControlErrorResponse(TypedDict):
|
||||
subtype: Literal['error']
|
||||
request_id: str
|
||||
error: str
|
||||
subtype: Literal["error"]
|
||||
request_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class SDKControlResponse(TypedDict):
|
||||
type: Literal['control_response']
|
||||
type: Literal["control_response"]
|
||||
response: ControlResponse | ControlErrorResponse
|
||||
|
||||
@@ -512,15 +512,19 @@ class TestClaudeSDKClientStreaming:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype")
|
||||
== "initialize"
|
||||
):
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default"
|
||||
}
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
@@ -592,9 +596,35 @@ while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
stdin_messages.append(line.strip())
|
||||
|
||||
# Verify we got 2 messages
|
||||
try:
|
||||
msg = json.loads(line.strip())
|
||||
# Handle control requests
|
||||
if msg.get("type") == "control_request":
|
||||
request_id = msg.get("request_id")
|
||||
request = msg.get("request", {})
|
||||
|
||||
# Send control response for initialize
|
||||
if request.get("subtype") == "initialize":
|
||||
response = {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {
|
||||
"commands": [],
|
||||
"output_style": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
print(json.dumps(response))
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
stdin_messages.append(line.strip())
|
||||
except:
|
||||
stdin_messages.append(line.strip())
|
||||
|
||||
# Verify we got 2 user messages
|
||||
assert len(stdin_messages) == 2
|
||||
assert '"First"' in stdin_messages[0]
|
||||
assert '"Second"' in stdin_messages[1]
|
||||
@@ -674,7 +704,7 @@ class TestClaudeSDKClientEdgeCases:
|
||||
# Create a new mock transport for each call
|
||||
mock_transport_class.side_effect = [
|
||||
create_mock_transport(),
|
||||
create_mock_transport()
|
||||
create_mock_transport(),
|
||||
]
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
@@ -736,15 +766,19 @@ class TestClaudeSDKClientEdgeCases:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype")
|
||||
== "initialize"
|
||||
):
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default"
|
||||
}
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
|
||||
Reference in New Issue
Block a user