Use anyio, not asyncio, in src

This commit is contained in:
Dickson Tsai
2025-09-02 06:35:00 +09:00
parent dd1ccf8b92
commit 05e932d2d1
5 changed files with 110 additions and 60 deletions

View File

@@ -1,6 +1,5 @@
"""Internal client implementation."""
import asyncio
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
@@ -55,9 +54,10 @@ class InternalClient:
await query.initialize()
# Stream input if it's an AsyncIterable
if isinstance(prompt, AsyncIterable):
if isinstance(prompt, AsyncIterable) and query._tg:
# Start streaming in background
asyncio.create_task(query.stream_input(prompt))
# Create a task that will run in the background
query._tg.start_soon(query.stream_input, prompt)
# For string prompts, the prompt is already passed via CLI args
# Yield parsed messages

View File

@@ -1,6 +1,5 @@
"""Query class for handling bidirectional control protocol."""
import asyncio
import json
import logging
import os
@@ -8,6 +7,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import suppress
from typing import Any
import anyio
from ..types import (
SDKControlPermissionRequest,
SDKControlRequest,
@@ -54,14 +55,17 @@ class Query:
self.hooks = hooks or {}
# Control protocol state
self.pending_control_responses: dict[str, asyncio.Future[dict[str, Any]]] = {}
self.pending_control_responses: dict[str, anyio.Event] = {}
self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
self.hook_callbacks: dict[str, Callable[..., Any]] = {}
self.next_callback_id = 0
self._request_counter = 0
# Message stream
self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._read_task: asyncio.Task[None] | None = None
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._tg: anyio.abc.TaskGroup | None = None
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
@@ -108,8 +112,10 @@ class Query:
async def start(self) -> None:
"""Start reading messages from transport."""
if self._read_task is None:
self._read_task = asyncio.create_task(self._read_messages())
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
@@ -125,20 +131,22 @@ class Query:
response = message.get("response", {})
request_id = response.get("request_id")
if request_id in self.pending_control_responses:
future = self.pending_control_responses.pop(request_id)
event = self.pending_control_responses[request_id]
if response.get("subtype") == "error":
future.set_exception(
Exception(response.get("error", "Unknown error"))
self.pending_control_results[request_id] = Exception(
response.get("error", "Unknown error")
)
else:
future.set_result(response)
self.pending_control_results[request_id] = response
event.set()
continue
elif msg_type == "control_request":
# Handle incoming control requests from CLI
# Cast message to SDKControlRequest for type safety
request: SDKControlRequest = message # type: ignore[assignment]
asyncio.create_task(self._handle_control_request(request))
if self._tg:
self._tg.start_soon(self._handle_control_request, request)
continue
elif msg_type == "control_cancel_request":
@@ -146,20 +154,20 @@ class Query:
# TODO: Implement cancellation support
continue
# Regular SDK messages go to the queue
await self._message_queue.put(message)
# Regular SDK messages go to the stream
await self._message_send.send(message)
except asyncio.CancelledError:
except anyio.get_cancelled_exc_class():
# Task was cancelled - this is expected behavior
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Put error in queue so iterators can handle it
await self._message_queue.put({"type": "error", "error": str(e)})
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
finally:
# Always signal end of stream
await self._message_queue.put({"type": "end"})
await self._message_send.send({"type": "end"})
async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
@@ -234,9 +242,9 @@ class Query:
self._request_counter += 1
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
# Create future for response
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
self.pending_control_responses[request_id] = future
# Create event for response
event = anyio.Event()
self.pending_control_responses[request_id] = event
# Build and send request
control_request = {
@@ -249,11 +257,20 @@ class Query:
# Wait for response
try:
response = await asyncio.wait_for(future, timeout=60.0)
result = response.get("response", {})
return result if isinstance(result, dict) else {}
except asyncio.TimeoutError as e:
with anyio.fail_after(60.0):
await event.wait()
result = self.pending_control_results.pop(request_id)
self.pending_control_responses.pop(request_id, None)
if isinstance(result, Exception):
raise result
response_data = result.get("response", {})
return response_data if isinstance(response_data, dict) else {}
except TimeoutError as e:
self.pending_control_responses.pop(request_id, None)
self.pending_control_results.pop(request_id, None)
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def interrupt(self) -> None:
@@ -283,25 +300,24 @@ class Query:
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive SDK messages (not control messages)."""
while not self._closed:
message = await self._message_queue.get()
async with self._message_receive:
async for message in self._message_receive:
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
yield message
yield message
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._read_task and not self._read_task.done():
self._read_task.cancel()
# Wait for task to complete cancellation
with suppress(asyncio.CancelledError):
await self._read_task
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
await self.transport.close()
# Make Query an async iterator

View File

@@ -1,6 +1,5 @@
"""Claude SDK Client for interacting with Claude Code."""
import asyncio
import json
import os
from collections.abc import AsyncIterable, AsyncIterator
@@ -138,8 +137,8 @@ class ClaudeSDKClient:
await self._query.initialize()
# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable):
asyncio.create_task(self._query.stream_input(prompt))
if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg:
self._query._tg.start_soon(self._query.stream_input, prompt)
async def receive_messages(self) -> AsyncIterator[Message]:
"""Receive all messages from Claude."""

View File

@@ -194,18 +194,19 @@ class SDKControlRequest(TypedDict):
| SDKControlMcpMessageRequest
)
class ControlResponse(TypedDict):
subtype: Literal['success']
request_id: str
response: dict[str, Any] | None
subtype: Literal["success"]
request_id: str
response: dict[str, Any] | None
class ControlErrorResponse(TypedDict):
subtype: Literal['error']
request_id: str
error: str
subtype: Literal["error"]
request_id: str
error: str
class SDKControlResponse(TypedDict):
type: Literal['control_response']
type: Literal["control_response"]
response: ControlResponse | ControlErrorResponse

View File

@@ -512,15 +512,19 @@ class TestClaudeSDKClientStreaming:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
@@ -592,9 +596,35 @@ while True:
line = sys.stdin.readline()
if not line:
break
stdin_messages.append(line.strip())
# Verify we got 2 messages
try:
msg = json.loads(line.strip())
# Handle control requests
if msg.get("type") == "control_request":
request_id = msg.get("request_id")
request = msg.get("request", {})
# Send control response for initialize
if request.get("subtype") == "initialize":
response = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": {
"commands": [],
"output_style": "default"
}
}
}
print(json.dumps(response))
sys.stdout.flush()
else:
stdin_messages.append(line.strip())
except:
stdin_messages.append(line.strip())
# Verify we got 2 user messages
assert len(stdin_messages) == 2
assert '"First"' in stdin_messages[0]
assert '"Second"' in stdin_messages[1]
@@ -674,7 +704,7 @@ class TestClaudeSDKClientEdgeCases:
# Create a new mock transport for each call
mock_transport_class.side_effect = [
create_mock_transport(),
create_mock_transport()
create_mock_transport(),
]
client = ClaudeSDKClient()
@@ -736,15 +766,19 @@ class TestClaudeSDKClientEdgeCases:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "control_request" and msg.get("request", {}).get("subtype") == "initialize":
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default"
}
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):