This commit is contained in:
Dickson Tsai
2025-07-19 15:25:34 -07:00
parent c95c077b9b
commit e65c2f417a

View File

@@ -3,7 +3,6 @@
import asyncio
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import AsyncMock, patch
@@ -373,80 +372,6 @@ class TestClaudeSDKClientStreaming:
class TestQueryWithAsyncIterable:
"""Test query() function with async iterable inputs."""
def _create_test_script(
self, expected_messages=None, response=None, should_error=False
):
"""Create a test script that validates CLI args and stdin messages.
Args:
expected_messages: List of expected message content strings, or None to skip validation
response: Custom response to output, defaults to a success result
should_error: If True, script will exit with error after reading stdin
Returns:
Path to the test script
"""
if response is None:
response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
script_content = textwrap.dedent(
"""
#!/usr/bin/env python3
import sys
import json
import time
# Check command line args
args = sys.argv[1:]
assert "--output-format" in args
assert "stream-json" in args
# Read stdin messages
stdin_messages = []
stdin_closed = False
try:
while True:
line = sys.stdin.readline()
if not line:
stdin_closed = True
break
stdin_messages.append(line.strip())
except:
stdin_closed = True
""",
)
if expected_messages is not None:
script_content += textwrap.dedent(
f"""
# Verify we got the expected messages
assert len(stdin_messages) == {len(expected_messages)}
""",
)
for i, msg in enumerate(expected_messages):
script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n'''
if should_error:
script_content += textwrap.dedent(
"""
sys.exit(1)
""",
)
else:
script_content += textwrap.dedent(
f"""
# Output response
print('{response}')
""",
)
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
test_script = f.name
f.write(script_content)
Path(test_script).chmod(0o755)
return test_script
def test_query_with_async_iterable(self):
"""Test query with async iterable of messages."""
@@ -455,32 +380,63 @@ class TestQueryWithAsyncIterable:
yield {"type": "user", "message": {"role": "user", "content": "First"}}
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
test_script = self._create_test_script(
expected_messages=["First", "Second"]
)
# Create a simple test script that validates stdin and outputs a result
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
test_script = f.name
f.write("""#!/usr/bin/env python3
import sys
import json
# Read stdin messages
stdin_messages = []
while True:
line = sys.stdin.readline()
if not line:
break
stdin_messages.append(line.strip())
# Verify we got 2 messages
assert len(stdin_messages) == 2
assert '"First"' in stdin_messages[0]
assert '"Second"' in stdin_messages[1]
# Output a valid result
print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
""")
Path(test_script).chmod(0o755)
try:
# Mock _build_command to return our test script
# Mock _find_cli to return python executing our test script
with patch.object(
SubprocessCLITransport,
"_build_command",
return_value=[
sys.executable,
test_script,
"--output-format",
"stream-json",
"--verbose",
],
"_find_cli",
return_value=sys.executable
):
# Run query with async iterable
messages = []
async for msg in query(prompt=message_stream()):
messages.append(msg)
# Mock _build_command to add our test script as first argument
original_build_command = SubprocessCLITransport._build_command
def mock_build_command(self):
# Get original command
cmd = original_build_command(self)
# Replace the CLI path with python + script
cmd[0] = test_script
return cmd
with patch.object(
SubprocessCLITransport,
"_build_command",
mock_build_command
):
# Run query with async iterable
messages = []
async for msg in query(prompt=message_stream()):
messages.append(msg)
# Should get the result message
assert len(messages) == 1
assert isinstance(messages[0], ResultMessage)
assert messages[0].subtype == "success"
# Should get the result message
assert len(messages) == 1
assert isinstance(messages[0], ResultMessage)
assert messages[0].subtype == "success"
finally:
# Clean up
Path(test_script).unlink()