Implement control protocol support for Python SDK (#139)

## Summary

This PR implements control protocol support in the Python SDK, aligning
it with the TypeScript implementation pattern. The refactor introduces a
Query + Transport separation to enable bidirectional communication
between the SDK and CLI.

## Motivation

The previous Python SDK implementation used a high-level abstraction in
the Transport ABC (`send_request`/`receive_messages`) that couldn't
handle bidirectional communication. This prevented support for:
- Control messages from CLI to SDK that need responses
- Hooks implementation  
- Dynamic permission mode changes
- SDK MCP servers

## Changes

### Core Architecture Refactor

1. **New Query Class** (`src/claude_code_sdk/_internal/query.py`)
   - Manages control protocol on top of Transport
   - Handles control request/response routing
   - Manages initialization handshake with timeout
   - Supports hook callbacks and tool permission callbacks
   - Implements message streaming

2. **Refactored Transport ABC**
(`src/claude_code_sdk/_internal/transport/__init__.py`)
- Changed from high-level (`send_request`/`receive_messages`) to
low-level (`write`/`read_messages`) interface
   - Now handles raw I/O instead of protocol logic
   - Aligns with TypeScript ProcessTransport pattern

3. **Updated SubprocessCLITransport**
(`src/claude_code_sdk/_internal/transport/subprocess_cli.py`)
   - Simplified to focus on raw message streaming
   - Removed protocol logic (moved to Query)
   - Improved cleanup and error handling

4. **Enhanced ClaudeSDKClient** (`src/claude_code_sdk/client.py`)
   - Now uses Query for control protocol
   - Supports initialization messages
   - Better error handling for control protocol failures

### Control Protocol Features

- **Initialization handshake**: SDK sends initialize request, CLI
responds with supported commands
- **Control message types**: 
  - `initialize`: Establish bidirectional connection
  - `interrupt`: Cancel ongoing operations  
  - `set_permission_mode`: Change permission mode dynamically
- **Timeout handling**: 60-second timeout for initialization to handle
CLI versions without control support

### Examples

Updated `examples/streaming_mode.py` to demonstrate control protocol
initialization and error handling.

## Testing

- Tested with current CLI (no control protocol support yet) - gracefully
falls back
- Verified backward compatibility with existing `query()` function
- Tested initialization timeout handling
- Verified proper cleanup on errors

## Design Alignment

This implementation closely follows the TypeScript reference:
- `src/core/Query.ts` → `src/claude_code_sdk/_internal/query.py`
- `src/transport/ProcessTransport.ts` →
`src/claude_code_sdk/_internal/transport/subprocess_cli.py`
- `src/entrypoints/sdk.ts` → `src/claude_code_sdk/client.py`

## Next Steps

Once the CLI implements the control protocol handler, this will enable:
- Hooks support
- Dynamic permission mode changes
- SDK MCP servers
- Improved error recovery

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Ashwin Bhat <ashwin@anthropic.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Kashyap Murali <kashyap@anthropic.com>
This commit is contained in:
Dickson Tsai
2025-09-02 15:04:22 +09:00
committed by GitHub
parent 9a64bc3a64
commit 22fa9f473e
13 changed files with 1002 additions and 238 deletions

3
.gitignore vendored
View File

@@ -33,6 +33,7 @@ env/
*.swp
*.swo
*~
**/.DS_Store
# Testing
.tox/
@@ -46,4 +47,4 @@ htmlcov/
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pyre/

View File

@@ -340,6 +340,85 @@ async def example_bash_command():
print("\n")
async def example_control_protocol():
"""Demonstrate server info and interrupt capabilities."""
print("=== Control Protocol Example ===")
print("Shows server info retrieval and interrupt capability\n")
async with ClaudeSDKClient() as client:
# 1. Get server initialization info
print("1. Getting server info...")
server_info = await client.get_server_info()
if server_info:
print("✓ Server info retrieved successfully!")
print(f" - Available commands: {len(server_info.get('commands', []))}")
print(f" - Output style: {server_info.get('output_style', 'unknown')}")
# Show available output styles if present
styles = server_info.get('available_output_styles', [])
if styles:
print(f" - Available output styles: {', '.join(styles)}")
# Show a few example commands
commands = server_info.get('commands', [])[:5]
if commands:
print(" - Example commands:")
for cmd in commands:
if isinstance(cmd, dict):
print(f"{cmd.get('name', 'unknown')}")
else:
print("✗ No server info available (may not be in streaming mode)")
print("\n2. Testing interrupt capability...")
# Start a long-running task
print("User: Count from 1 to 20 slowly")
await client.query("Count from 1 to 20 slowly, pausing between each number")
# Start consuming messages in background to enable interrupt
messages = []
async def consume():
async for msg in client.receive_response():
messages.append(msg)
if isinstance(msg, AssistantMessage):
for block in msg.content:
if isinstance(block, TextBlock):
# Print first 50 chars to show progress
print(f"Claude: {block.text[:50]}...")
break
if isinstance(msg, ResultMessage):
break
consume_task = asyncio.create_task(consume())
# Wait a moment then interrupt
await asyncio.sleep(2)
print("\n[Sending interrupt after 2 seconds...]")
try:
await client.interrupt()
print("✓ Interrupt sent successfully")
except Exception as e:
print(f"✗ Interrupt failed: {e}")
# Wait for task to complete
with contextlib.suppress(asyncio.CancelledError):
await consume_task
# Send new query after interrupt
print("\nUser: Just say 'Hello!'")
await client.query("Just say 'Hello!'")
async for msg in client.receive_response():
if isinstance(msg, AssistantMessage):
for block in msg.content:
if isinstance(block, TextBlock):
print(f"Claude: {block.text}")
print("\n")
async def example_error_handling():
"""Demonstrate proper error handling."""
print("=== Error Handling Example ===")
@@ -350,8 +429,8 @@ async def example_error_handling():
await client.connect()
# Send a message that will take time to process
print("User: Run a bash sleep command for 60 seconds")
await client.query("Run a bash sleep command for 60 seconds")
print("User: Run a bash sleep command for 60 seconds not in the background")
await client.query("Run a bash sleep command for 60 seconds not in the background")
# Try to receive response with a short timeout
try:
@@ -397,6 +476,7 @@ async def main():
"with_options": example_with_options,
"async_iterable_prompt": example_async_iterable_prompt,
"bash_command": example_bash_command,
"control_protocol": example_control_protocol,
"error_handling": example_error_handling,
}

View File

@@ -8,6 +8,7 @@ from ..types import (
Message,
)
from .message_parser import parse_message
from .query import Query
from .transport import Transport
from .transport.subprocess_cli import SubprocessCLITransport
@@ -24,21 +25,44 @@ class InternalClient:
options: ClaudeCodeOptions,
transport: Transport | None = None,
) -> AsyncIterator[Message]:
"""Process a query through transport."""
"""Process a query through transport and Query."""
# Use provided transport or choose one based on configuration
# Use provided transport or create subprocess transport
if transport is not None:
chosen_transport = transport
else:
chosen_transport = SubprocessCLITransport(
prompt=prompt, options=options, close_stdin_after_prompt=True
)
chosen_transport = SubprocessCLITransport(prompt=prompt, options=options)
# Connect transport
await chosen_transport.connect()
# Create Query to handle control protocol
is_streaming = not isinstance(prompt, str)
query = Query(
transport=chosen_transport,
is_streaming_mode=is_streaming,
can_use_tool=None, # TODO: Add support for can_use_tool callback
hooks=None, # TODO: Add support for hooks
)
try:
await chosen_transport.connect()
# Start reading messages
await query.start()
async for data in chosen_transport.receive_messages():
# Initialize if streaming
if is_streaming:
await query.initialize()
# Stream input if it's an AsyncIterable
if isinstance(prompt, AsyncIterable) and query._tg:
# Start streaming in background
# Create a task that will run in the background
query._tg.start_soon(query.stream_input, prompt)
# For string prompts, the prompt is already passed via CLI args
# Yield parsed messages
async for data in query.receive_messages():
yield parse_message(data)
finally:
await chosen_transport.disconnect()
await query.close()

View File

@@ -0,0 +1,332 @@
"""Query class for handling bidirectional control protocol."""
import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import suppress
from typing import Any
import anyio
from ..types import (
SDKControlPermissionRequest,
SDKControlRequest,
SDKControlResponse,
SDKHookCallbackRequest,
)
from .transport import Transport
logger = logging.getLogger(__name__)
class Query:
"""Handles bidirectional control protocol on top of Transport.
This class manages:
- Control request/response routing
- Hook callbacks
- Tool permission callbacks
- Message streaming
- Initialization handshake
"""
def __init__(
self,
transport: Transport,
is_streaming_mode: bool,
can_use_tool: Callable[
[str, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
]
| None = None,
hooks: dict[str, list[dict[str, Any]]] | None = None,
):
"""Initialize Query with transport and callbacks.
Args:
transport: Low-level transport for I/O
is_streaming_mode: Whether using streaming (bidirectional) mode
can_use_tool: Optional callback for tool permission requests
hooks: Optional hook configurations
"""
self.transport = transport
self.is_streaming_mode = is_streaming_mode
self.can_use_tool = can_use_tool
self.hooks = hooks or {}
# Control protocol state
self.pending_control_responses: dict[str, anyio.Event] = {}
self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
self.hook_callbacks: dict[str, Callable[..., Any]] = {}
self.next_callback_id = 0
self._request_counter = 0
# Message stream
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._tg: anyio.abc.TaskGroup | None = None
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.
Returns:
Initialize response with supported commands, or None if not streaming
"""
if not self.is_streaming_mode:
return None
# Build hooks configuration for initialization
hooks_config: dict[str, Any] = {}
if self.hooks:
for event, matchers in self.hooks.items():
if matchers:
hooks_config[event] = []
for matcher in matchers:
callback_ids = []
for callback in matcher.get("hooks", []):
callback_id = f"hook_{self.next_callback_id}"
self.next_callback_id += 1
self.hook_callbacks[callback_id] = callback
callback_ids.append(callback_id)
hooks_config[event].append(
{
"matcher": matcher.get("matcher"),
"hookCallbackIds": callback_ids,
}
)
# Send initialize request
request = {
"subtype": "initialize",
"hooks": hooks_config if hooks_config else None,
}
response = await self._send_control_request(request)
self._initialized = True
self._initialization_result = response # Store for later access
return response
async def start(self) -> None:
"""Start reading messages from transport."""
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
try:
async for message in self.transport.read_messages():
if self._closed:
break
msg_type = message.get("type")
# Route control messages
if msg_type == "control_response":
response = message.get("response", {})
request_id = response.get("request_id")
if request_id in self.pending_control_responses:
event = self.pending_control_responses[request_id]
if response.get("subtype") == "error":
self.pending_control_results[request_id] = Exception(
response.get("error", "Unknown error")
)
else:
self.pending_control_results[request_id] = response
event.set()
continue
elif msg_type == "control_request":
# Handle incoming control requests from CLI
# Cast message to SDKControlRequest for type safety
request: SDKControlRequest = message # type: ignore[assignment]
if self._tg:
self._tg.start_soon(self._handle_control_request, request)
continue
elif msg_type == "control_cancel_request":
# Handle cancel requests
# TODO: Implement cancellation support
continue
# Regular SDK messages go to the stream
await self._message_send.send(message)
except anyio.get_cancelled_exc_class():
# Task was cancelled - this is expected behavior
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
finally:
# Always signal end of stream
await self._message_send.send({"type": "end"})
async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
request_id = request["request_id"]
request_data = request["request"]
subtype = request_data["subtype"]
try:
response_data = {}
if subtype == "can_use_tool":
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
# Handle tool permission request
if not self.can_use_tool:
raise Exception("canUseTool callback is not provided")
response_data = await self.can_use_tool(
permission_request["tool_name"],
permission_request["input"],
{
"signal": None, # TODO: Add abort signal support
"suggestions": permission_request.get("permission_suggestions"),
},
)
elif subtype == "hook_callback":
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
# Handle hook callback
callback_id = hook_callback_request["callback_id"]
callback = self.hook_callbacks.get(callback_id)
if not callback:
raise Exception(f"No hook callback found for ID: {callback_id}")
response_data = await callback(
request_data.get("input"),
request_data.get("tool_use_id"),
{"signal": None}, # TODO: Add abort signal support
)
else:
raise Exception(f"Unsupported control request subtype: {subtype}")
# Send success response
success_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": response_data,
},
}
await self.transport.write(json.dumps(success_response) + "\n")
except Exception as e:
# Send error response
error_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "error",
"request_id": request_id,
"error": str(e),
},
}
await self.transport.write(json.dumps(error_response) + "\n")
async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send control request to CLI and wait for response."""
if not self.is_streaming_mode:
raise Exception("Control requests require streaming mode")
# Generate unique request ID
self._request_counter += 1
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
# Create event for response
event = anyio.Event()
self.pending_control_responses[request_id] = event
# Build and send request
control_request = {
"type": "control_request",
"request_id": request_id,
"request": request,
}
await self.transport.write(json.dumps(control_request) + "\n")
# Wait for response
try:
with anyio.fail_after(60.0):
await event.wait()
result = self.pending_control_results.pop(request_id)
self.pending_control_responses.pop(request_id, None)
if isinstance(result, Exception):
raise result
response_data = result.get("response", {})
return response_data if isinstance(response_data, dict) else {}
except TimeoutError as e:
self.pending_control_responses.pop(request_id, None)
self.pending_control_results.pop(request_id, None)
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def interrupt(self) -> None:
"""Send interrupt control request."""
await self._send_control_request({"subtype": "interrupt"})
async def set_permission_mode(self, mode: str) -> None:
"""Change permission mode."""
await self._send_control_request(
{
"subtype": "set_permission_mode",
"mode": mode,
}
)
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport."""
try:
async for message in stream:
if self._closed:
break
await self.transport.write(json.dumps(message) + "\n")
# After all messages sent, end input
await self.transport.end_input()
except Exception as e:
logger.debug(f"Error streaming input: {e}")
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive SDK messages (not control messages)."""
async with self._message_receive:
async for message in self._message_receive:
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
yield message
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
await self.transport.close()
# Make Query an async iterator
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
"""Return async iterator for messages."""
return self.receive_messages()
async def __anext__(self) -> dict[str, Any]:
"""Get next message."""
async for message in self.receive_messages():
return message
raise StopAsyncIteration

View File

@@ -12,33 +12,56 @@ class Transport(ABC):
(e.g., remote Claude Code connections). The Claude Code team may change or
or remove this abstract class in any future release. Custom implementations
must be updated to match interface changes.
This is a low-level transport interface that handles raw I/O with the Claude
process or service. The Query class builds on top of this to implement the
control protocol and message routing.
"""
@abstractmethod
async def connect(self) -> None:
"""Initialize connection."""
"""Connect the transport and prepare for communication.
For subprocess transports, this starts the process.
For network transports, this establishes the connection.
"""
pass
@abstractmethod
async def disconnect(self) -> None:
"""Close connection."""
async def write(self, data: str) -> None:
"""Write raw data to the transport.
Args:
data: Raw string data to write (typically JSON + newline)
"""
pass
@abstractmethod
async def send_request(
self, messages: list[dict[str, Any]], options: dict[str, Any]
) -> None:
"""Send request to Claude."""
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport.
Yields:
Parsed JSON messages from the transport
"""
pass
@abstractmethod
def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive messages from Claude."""
async def close(self) -> None:
"""Close the transport connection and clean up resources."""
pass
@abstractmethod
def is_connected(self) -> bool:
"""Check if transport is connected."""
def is_ready(self) -> bool:
"""Check if transport is ready for communication.
Returns:
True if transport is ready to send/receive messages
"""
pass
@abstractmethod
async def end_input(self) -> None:
"""End the input stream (close stdin for process transports)."""
pass

View File

@@ -7,6 +7,7 @@ import shutil
import tempfile
from collections import deque
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import suppress
from pathlib import Path
from subprocess import PIPE
from typing import Any
@@ -33,7 +34,6 @@ class SubprocessCLITransport(Transport):
prompt: str | AsyncIterable[dict[str, Any]],
options: ClaudeCodeOptions,
cli_path: str | Path | None = None,
close_stdin_after_prompt: bool = False,
):
self._prompt = prompt
self._is_streaming = not isinstance(prompt, str)
@@ -44,11 +44,8 @@ class SubprocessCLITransport(Transport):
self._stdout_stream: TextReceiveStream | None = None
self._stderr_stream: TextReceiveStream | None = None
self._stdin_stream: TextSendStream | None = None
self._pending_control_responses: dict[str, dict[str, Any]] = {}
self._request_counter = 0
self._close_stdin_after_prompt = close_stdin_after_prompt
self._task_group: anyio.abc.TaskGroup | None = None
self._stderr_file: Any = None # tempfile.NamedTemporaryFile
self._ready = False
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
@@ -174,7 +171,6 @@ class SubprocessCLITransport(Transport):
mode="w+", prefix="claude_stderr_", suffix=".log", delete=False
)
# Enable stdin pipe for both modes (but we'll close it for string mode)
# Merge environment variables: system -> user -> SDK required
process_env = {
**os.environ,
@@ -197,19 +193,14 @@ class SubprocessCLITransport(Transport):
if self._process.stdout:
self._stdout_stream = TextReceiveStream(self._process.stdout)
# Handle stdin based on mode
if self._is_streaming:
# Streaming mode: keep stdin open and start streaming task
if self._process.stdin:
self._stdin_stream = TextSendStream(self._process.stdin)
# Start streaming messages to stdin in background
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._stream_to_stdin)
else:
# String mode: close stdin immediately (backward compatible)
if self._process.stdin:
await self._process.stdin.aclose()
# Setup stdin for streaming mode
if self._is_streaming and self._process.stdin:
self._stdin_stream = TextSendStream(self._process.stdin)
elif not self._is_streaming and self._process.stdin:
# String mode: close stdin immediately
await self._process.stdin.aclose()
self._ready = True
except FileNotFoundError as e:
# Check if the error comes from the working directory or the CLI
@@ -221,27 +212,31 @@ class SubprocessCLITransport(Transport):
except Exception as e:
raise CLIConnectionError(f"Failed to start Claude Code: {e}") from e
async def disconnect(self) -> None:
"""Terminate subprocess."""
async def close(self) -> None:
"""Close the transport and clean up resources."""
self._ready = False
if not self._process:
return
# Cancel task group if it exists
if self._task_group:
self._task_group.cancel_scope.cancel()
await self._task_group.__aexit__(None, None, None)
self._task_group = None
# Close stdin first if it's still open
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()
# Terminate and wait for process
if self._process.returncode is None:
try:
with suppress(ProcessLookupError):
self._process.terminate()
with anyio.fail_after(5.0):
# Wait for process to finish with timeout
with suppress(Exception):
# Just try to wait, but don't block if it fails
await self._process.wait()
except TimeoutError:
self._process.kill()
await self._process.wait()
except ProcessLookupError:
pass
# Clean up temp file
if self._stderr_file:
@@ -257,57 +252,35 @@ class SubprocessCLITransport(Transport):
self._stderr_stream = None
self._stdin_stream = None
async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None:
"""Send additional messages in streaming mode."""
if not self._is_streaming:
raise CLIConnectionError("send_request only works in streaming mode")
async def write(self, data: str) -> None:
"""Write raw data to the transport."""
if not self._stdin_stream:
raise CLIConnectionError("stdin not available - stream may have ended")
raise CLIConnectionError("Cannot write: stdin not available")
# Send each message as a user message
for message in messages:
# Ensure message has required structure
if not isinstance(message, dict):
message = {
"type": "user",
"message": {"role": "user", "content": str(message)},
"parent_tool_use_id": None,
"session_id": options.get("session_id", "default"),
}
await self._stdin_stream.send(data)
await self._stdin_stream.send(json.dumps(message) + "\n")
async def _stream_to_stdin(self) -> None:
"""Stream messages to stdin for streaming mode."""
if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable):
return
try:
async for message in self._prompt:
if not self._stdin_stream:
break
await self._stdin_stream.send(json.dumps(message) + "\n")
# Close stdin after prompt if requested (e.g., for query() one-shot mode)
if self._close_stdin_after_prompt and self._stdin_stream:
async def end_input(self) -> None:
"""End the input stream (close stdin)."""
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
# Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode)
except Exception as e:
logger.debug(f"Error streaming to stdin: {e}")
if self._stdin_stream:
await self._stdin_stream.aclose()
self._stdin_stream = None
self._stdin_stream = None
if self._process and self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive messages from CLI."""
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""
return self._read_messages_impl()
async def _read_messages_impl(self) -> AsyncIterator[dict[str, Any]]:
"""Internal implementation of read_messages."""
if not self._process or not self._stdout_stream:
raise CLIConnectionError("Not connected")
json_buffer = ""
# Process stdout messages first
# Process stdout messages
try:
async for line in self._stdout_stream:
line_str = line.strip()
@@ -336,20 +309,7 @@ class SubprocessCLITransport(Transport):
try:
data = json.loads(json_buffer)
json_buffer = ""
# Handle control responses separately
if data.get("type") == "control_response":
response = data.get("response", {})
request_id = response.get("request_id")
if request_id:
# Store the response for the pending request
self._pending_control_responses[request_id] = response
continue
try:
yield data
except GeneratorExit:
return
yield data
except json.JSONDecodeError:
# We are speculatively decoding the buffer until we get
# a full JSON object. If there is an actual issue, we
@@ -359,7 +319,7 @@ class SubprocessCLITransport(Transport):
except anyio.ClosedResourceError:
pass
except GeneratorExit:
# Client disconnected - still need to clean up
# Client disconnected
pass
# Read stderr from temp file (keep only last N lines for memory efficiency)
@@ -402,48 +362,12 @@ class SubprocessCLITransport(Transport):
# Log stderr for debugging but don't fail on non-zero exit
logger.debug(f"Process stderr: {stderr_output}")
def is_connected(self) -> bool:
"""Check if subprocess is running."""
return self._process is not None and self._process.returncode is None
def is_ready(self) -> bool:
"""Check if transport is ready for communication."""
return (
self._ready
and self._process is not None
and self._process.returncode is None
)
async def interrupt(self) -> None:
"""Send interrupt control request (only works in streaming mode)."""
if not self._is_streaming:
raise CLIConnectionError(
"Interrupt requires streaming mode (AsyncIterable prompt)"
)
if not self._stdin_stream:
raise CLIConnectionError("Not connected or stdin not available")
await self._send_control_request({"subtype": "interrupt"})
async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a control request and wait for response."""
if not self._stdin_stream:
raise CLIConnectionError("Stdin not available")
# Generate unique request ID
self._request_counter += 1
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
# Build control request
control_request = {
"type": "control_request",
"request_id": request_id,
"request": request,
}
# Send request
await self._stdin_stream.send(json.dumps(control_request) + "\n")
# Wait for response
while request_id not in self._pending_control_responses:
await anyio.sleep(0.1)
response = self._pending_control_responses.pop(request_id)
if response.get("subtype") == "error":
raise CLIConnectionError(f"Control request failed: {response.get('error')}")
return response
# Remove interrupt and control request methods - these now belong in Query class

View File

@@ -1,5 +1,6 @@
"""Claude SDK Client for interacting with Claude Code."""
import json
import os
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
@@ -96,12 +97,15 @@ class ClaudeSDKClient:
options = ClaudeCodeOptions()
self.options = options
self._transport: Any | None = None
self._query: Any | None = None
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
async def connect(
self, prompt: str | AsyncIterable[dict[str, Any]] | None = None
) -> None:
"""Connect to Claude with a prompt or message stream."""
from ._internal.query import Query
from ._internal.transport.subprocess_cli import SubprocessCLITransport
# Auto-connect with empty async iterable if no prompt is provided
@@ -112,20 +116,38 @@ class ClaudeSDKClient:
return
yield {} # type: ignore[unreachable]
actual_prompt = _empty_stream() if prompt is None else prompt
self._transport = SubprocessCLITransport(
prompt=_empty_stream() if prompt is None else prompt,
prompt=actual_prompt,
options=self.options,
)
await self._transport.connect()
# Create Query to handle control protocol
self._query = Query(
transport=self._transport,
is_streaming_mode=True, # ClaudeSDKClient always uses streaming mode
can_use_tool=None, # TODO: Add support for can_use_tool callback
hooks=None, # TODO: Add support for hooks
)
# Start reading messages and initialize
await self._query.start()
await self._query.initialize()
# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable) and self._query._tg:
self._query._tg.start_soon(self._query.stream_input, prompt)
async def receive_messages(self) -> AsyncIterator[Message]:
"""Receive all messages from Claude."""
if not self._transport:
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")
from ._internal.message_parser import parse_message
async for data in self._transport.receive_messages():
async for data in self._query.receive_messages():
yield parse_message(data)
async def query(
@@ -138,7 +160,7 @@ class ClaudeSDKClient:
prompt: Either a string message or an async iterable of message dictionaries
session_id: Session identifier for the conversation
"""
if not self._transport:
if not self._query or not self._transport:
raise CLIConnectionError("Not connected. Call connect() first.")
# Handle string prompts
@@ -149,24 +171,45 @@ class ClaudeSDKClient:
"parent_tool_use_id": None,
"session_id": session_id,
}
await self._transport.send_request([message], {"session_id": session_id})
await self._transport.write(json.dumps(message) + "\n")
else:
# Handle AsyncIterable prompts
messages = []
# Handle AsyncIterable prompts - stream them
async for msg in prompt:
# Ensure session_id is set on each message
if "session_id" not in msg:
msg["session_id"] = session_id
messages.append(msg)
if messages:
await self._transport.send_request(messages, {"session_id": session_id})
await self._transport.write(json.dumps(msg) + "\n")
async def interrupt(self) -> None:
"""Send interrupt signal (only works with streaming mode)."""
if not self._transport:
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")
await self._transport.interrupt()
await self._query.interrupt()
async def get_server_info(self) -> dict[str, Any] | None:
"""Get server initialization info including available commands and output styles.
Returns initialization information from the Claude Code server including:
- Available commands (slash commands, system commands, etc.)
- Current and available output styles
- Server capabilities
Returns:
Dictionary with server info, or None if not in streaming mode
Example:
```python
async with ClaudeSDKClient() as client:
info = await client.get_server_info()
if info:
print(f"Commands available: {len(info.get('commands', []))}")
print(f"Output style: {info.get('output_style', 'default')}")
```
"""
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")
# Return the initialization result that was already obtained during connect
return getattr(self._query, "_initialization_result", None)
async def receive_response(self) -> AsyncIterator[Message]:
"""
@@ -211,9 +254,10 @@ class ClaudeSDKClient:
async def disconnect(self) -> None:
"""Disconnect from Claude."""
if self._transport:
await self._transport.disconnect()
self._transport = None
if self._query:
await self._query.close()
self._query = None
self._transport = None
async def __aenter__(self) -> "ClaudeSDKClient":
"""Enter async context - automatically connects with empty stream for interactive use."""

View File

@@ -141,3 +141,72 @@ class ClaudeCodeOptions:
extra_args: dict[str, str | None] = field(
default_factory=dict
) # Pass arbitrary CLI flags
# SDK Control Protocol
class SDKControlInterruptRequest(TypedDict):
subtype: Literal["interrupt"]
class SDKControlPermissionRequest(TypedDict):
subtype: Literal["can_use_tool"]
tool_name: str
input: dict[str, Any]
# TODO: Add PermissionUpdate type here
permission_suggestions: list[Any] | None
blocked_path: str | None
class SDKControlInitializeRequest(TypedDict):
subtype: Literal["initialize"]
# TODO: Use HookEvent names as the key.
hooks: dict[str, Any] | None
class SDKControlSetPermissionModeRequest(TypedDict):
subtype: Literal["set_permission_mode"]
# TODO: Add PermissionMode
mode: str
class SDKHookCallbackRequest(TypedDict):
subtype: Literal["hook_callback"]
callback_id: str
input: Any
tool_use_id: str | None
class SDKControlMcpMessageRequest(TypedDict):
subtype: Literal["mcp_message"]
server_name: str
message: Any
class SDKControlRequest(TypedDict):
type: Literal["control_request"]
request_id: str
request: (
SDKControlInterruptRequest
| SDKControlPermissionRequest
| SDKControlInitializeRequest
| SDKControlSetPermissionModeRequest
| SDKHookCallbackRequest
| SDKControlMcpMessageRequest
)
class ControlResponse(TypedDict):
subtype: Literal["success"]
request_id: str
response: dict[str, Any] | None
class ControlErrorResponse(TypedDict):
subtype: Literal["error"]
request_id: str
error: str
class SDKControlResponse(TypedDict):
type: Literal["control_response"]
response: ControlResponse | ControlErrorResponse

View File

@@ -1,6 +1,6 @@
"""Tests for Claude SDK client functionality."""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
@@ -102,9 +102,12 @@ class TestQueryFunction:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
options = ClaudeCodeOptions(cwd="/custom/path")
messages = []

View File

@@ -3,7 +3,7 @@
These tests verify end-to-end functionality with mocked CLI responses.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
import pytest
@@ -52,9 +52,12 @@ class TestIntegration:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query
messages = []
@@ -118,9 +121,12 @@ class TestIntegration:
"total_cost_usd": 0.002,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query with tools enabled
messages = []
@@ -185,9 +191,12 @@ class TestIntegration:
},
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query with continuation
messages = []

View File

@@ -1,10 +1,11 @@
"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables."""
import asyncio
import json
import sys
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
import pytest
@@ -22,6 +23,90 @@ from claude_code_sdk import (
from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
def create_mock_transport(with_init_response=True):
"""Create a properly configured mock transport.
Args:
with_init_response: If True, automatically respond to initialization request
"""
mock_transport = AsyncMock()
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Track written messages to simulate control protocol responses
written_messages = []
async def mock_write(data):
written_messages.append(data)
mock_transport.write.side_effect = mock_write
# Default read_messages to handle control protocol
async def control_protocol_generator():
# Wait for initialization request if needed
if with_init_response:
# Wait a bit for the write to happen
await asyncio.sleep(0.01)
# Check if initialization was requested
for msg_str in written_messages:
try:
msg = json.loads(msg_str.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "initialize"
):
# Send initialization response
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Keep checking for other control requests (like interrupt)
last_check = len(written_messages)
timeout_counter = 0
while timeout_counter < 100: # Avoid infinite loop
await asyncio.sleep(0.01)
timeout_counter += 1
# Check for new messages
for msg_str in written_messages[last_check:]:
try:
msg = json.loads(msg_str.strip())
if msg.get("type") == "control_request":
subtype = msg.get("request", {}).get("subtype")
if subtype == "interrupt":
# Send interrupt response
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
},
}
return # End after interrupt
except (json.JSONDecodeError, KeyError, AttributeError):
pass
last_check = len(written_messages)
# Then end the stream
return
mock_transport.read_messages = control_protocol_generator
return mock_transport
class TestClaudeSDKClientStreaming:
"""Test ClaudeSDKClient streaming functionality."""
@@ -32,7 +117,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
@@ -41,7 +126,7 @@ class TestClaudeSDKClientStreaming:
assert client._transport is mock_transport
# Verify disconnect was called on exit
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
anyio.run(_test)
@@ -52,7 +137,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient()
@@ -64,7 +149,7 @@ class TestClaudeSDKClientStreaming:
await client.disconnect()
# Verify disconnect was called
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
assert client._transport is None
anyio.run(_test)
@@ -76,7 +161,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient()
@@ -95,7 +180,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async def message_stream():
@@ -123,20 +208,30 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
await client.query("Test message")
# Verify send_request was called with correct format
mock_transport.send_request.assert_called_once()
call_args = mock_transport.send_request.call_args
messages, options = call_args[0]
assert len(messages) == 1
assert messages[0]["type"] == "user"
assert messages[0]["message"]["content"] == "Test message"
assert options["session_id"] == "default"
# Verify write was called with correct format
# Should have at least 2 writes: init request and user message
assert mock_transport.write.call_count >= 2
# Find the user message in the write calls
user_msg_found = False
for call in mock_transport.write.call_args_list:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "user":
assert msg["message"]["content"] == "Test message"
assert msg["session_id"] == "default"
user_msg_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert user_msg_found, "User message not found in write calls"
anyio.run(_test)
@@ -147,16 +242,25 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
await client.query("Test", session_id="custom-session")
call_args = mock_transport.send_request.call_args
messages, options = call_args[0]
assert messages[0]["session_id"] == "custom-session"
assert options["session_id"] == "custom-session"
# Find the user message with custom session ID
session_found = False
for call in mock_transport.write.call_args_list:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "user":
assert msg["session_id"] == "custom-session"
session_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert session_found, "User message with custom session not found"
anyio.run(_test)
@@ -177,11 +281,37 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {
@@ -195,7 +325,7 @@ class TestClaudeSDKClientStreaming:
"message": {"role": "user", "content": "Hi there"},
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
messages = []
@@ -220,11 +350,37 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {
@@ -255,7 +411,7 @@ class TestClaudeSDKClientStreaming:
"model": "claude-opus-4-1-20250805",
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
messages = []
@@ -276,12 +432,28 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
# Interrupt is now handled via control protocol
await client.interrupt()
mock_transport.interrupt.assert_called_once()
# Check that a control request was sent via write
write_calls = mock_transport.write.call_args_list
interrupt_found = False
for call in write_calls:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "interrupt"
):
interrupt_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert interrupt_found, "Interrupt control request not found"
anyio.run(_test)
@@ -308,7 +480,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient(options=options)
@@ -327,11 +499,38 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock receive to wait then yield messages
# Mock receive to wait then yield messages with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
if call:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
await asyncio.sleep(0.1)
yield {
"type": "assistant",
@@ -353,7 +552,7 @@ class TestClaudeSDKClientStreaming:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
# Helper to get next message
@@ -397,9 +596,35 @@ while True:
line = sys.stdin.readline()
if not line:
break
stdin_messages.append(line.strip())
# Verify we got 2 messages
try:
msg = json.loads(line.strip())
# Handle control requests
if msg.get("type") == "control_request":
request_id = msg.get("request_id")
request = msg.get("request", {})
# Send control response for initialize
if request.get("subtype") == "initialize":
response = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": {
"commands": [],
"output_style": "default"
}
}
}
print(json.dumps(response))
sys.stdout.flush()
else:
stdin_messages.append(line.strip())
except:
stdin_messages.append(line.strip())
# Verify we got 2 user messages
assert len(stdin_messages) == 2
assert '"First"' in stdin_messages[0]
assert '"Second"' in stdin_messages[1]
@@ -476,8 +701,11 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport_class.return_value = mock_transport
# Create a new mock transport for each call
mock_transport_class.side_effect = [
create_mock_transport(),
create_mock_transport(),
]
client = ClaudeSDKClient()
await client.connect()
@@ -506,7 +734,7 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
with pytest.raises(ValueError):
@@ -514,7 +742,7 @@ class TestClaudeSDKClientEdgeCases:
raise ValueError("Test error")
# Disconnect should still be called
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
anyio.run(_test)
@@ -525,11 +753,38 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
if call:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {
@@ -557,7 +812,7 @@ class TestClaudeSDKClientEdgeCases:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
# Test list comprehension pattern from docstring

View File

@@ -63,7 +63,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([]) # type: ignore[assignment]
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@@ -97,7 +97,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@@ -127,7 +127,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 2
@@ -173,7 +173,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 1
@@ -221,7 +221,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 1
@@ -252,7 +252,7 @@ class TestSubprocessBuffering:
with pytest.raises(Exception) as exc_info:
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert isinstance(exc_info.value, CLIJSONDecodeError)
@@ -293,7 +293,7 @@ class TestSubprocessBuffering:
transport._stderr_stream = MockTextReceiveStream([])
messages: list[Any] = []
async for msg in transport.receive_messages():
async for msg in transport.read_messages():
messages.append(msg)
assert len(messages) == 3

View File

@@ -112,8 +112,8 @@ class TestSubprocessCLITransport:
assert "--resume" in cmd
assert "session-123" in cmd
def test_connect_disconnect(self):
"""Test connect and disconnect lifecycle."""
def test_connect_close(self):
"""Test connect and close lifecycle."""
async def _test():
with patch("anyio.open_process") as mock_exec:
@@ -139,22 +139,22 @@ class TestSubprocessCLITransport:
await transport.connect()
assert transport._process is not None
assert transport.is_connected()
assert transport.is_ready()
await transport.disconnect()
await transport.close()
mock_process.terminate.assert_called_once()
anyio.run(_test)
def test_receive_messages(self):
"""Test parsing messages from CLI output."""
# This test is simplified to just test the parsing logic
def test_read_messages(self):
"""Test reading messages from CLI output."""
# This test is simplified to just test the transport creation
# The full async stream handling is tested in integration tests
transport = SubprocessCLITransport(
prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude"
)
# The actual message parsing is done by the client, not the transport
# The transport now just provides raw message reading via read_messages()
# So we just verify the transport can be created and basic structure is correct
assert transport._prompt == "test"
assert transport._cli_path == "/usr/bin/claude"