This commit is contained in:
Dickson Tsai
2025-07-19 15:25:34 -07:00
parent c95c077b9b
commit e65c2f417a

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
import sys import sys
import tempfile import tempfile
import textwrap
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@@ -373,80 +372,6 @@ class TestClaudeSDKClientStreaming:
class TestQueryWithAsyncIterable: class TestQueryWithAsyncIterable:
"""Test query() function with async iterable inputs.""" """Test query() function with async iterable inputs."""
def _create_test_script(
self, expected_messages=None, response=None, should_error=False
):
"""Create a test script that validates CLI args and stdin messages.
Args:
expected_messages: List of expected message content strings, or None to skip validation
response: Custom response to output, defaults to a success result
should_error: If True, script will exit with error after reading stdin
Returns:
Path to the test script
"""
if response is None:
response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
script_content = textwrap.dedent(
"""
#!/usr/bin/env python3
import sys
import json
import time
# Check command line args
args = sys.argv[1:]
assert "--output-format" in args
assert "stream-json" in args
# Read stdin messages
stdin_messages = []
stdin_closed = False
try:
while True:
line = sys.stdin.readline()
if not line:
stdin_closed = True
break
stdin_messages.append(line.strip())
except:
stdin_closed = True
""",
)
if expected_messages is not None:
script_content += textwrap.dedent(
f"""
# Verify we got the expected messages
assert len(stdin_messages) == {len(expected_messages)}
""",
)
for i, msg in enumerate(expected_messages):
script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n'''
if should_error:
script_content += textwrap.dedent(
"""
sys.exit(1)
""",
)
else:
script_content += textwrap.dedent(
f"""
# Output response
print('{response}')
""",
)
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
test_script = f.name
f.write(script_content)
Path(test_script).chmod(0o755)
return test_script
def test_query_with_async_iterable(self): def test_query_with_async_iterable(self):
"""Test query with async iterable of messages.""" """Test query with async iterable of messages."""
@@ -455,32 +380,63 @@ class TestQueryWithAsyncIterable:
yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "First"}}
yield {"type": "user", "message": {"role": "user", "content": "Second"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}}
test_script = self._create_test_script( # Create a simple test script that validates stdin and outputs a result
expected_messages=["First", "Second"] with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
) test_script = f.name
f.write("""#!/usr/bin/env python3
import sys
import json
# Read stdin messages
stdin_messages = []
while True:
line = sys.stdin.readline()
if not line:
break
stdin_messages.append(line.strip())
# Verify we got 2 messages
assert len(stdin_messages) == 2
assert '"First"' in stdin_messages[0]
assert '"Second"' in stdin_messages[1]
# Output a valid result
print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
""")
Path(test_script).chmod(0o755)
try: try:
# Mock _build_command to return our test script # Mock _find_cli to return python executing our test script
with patch.object( with patch.object(
SubprocessCLITransport, SubprocessCLITransport,
"_build_command", "_find_cli",
return_value=[ return_value=sys.executable
sys.executable,
test_script,
"--output-format",
"stream-json",
"--verbose",
],
): ):
# Run query with async iterable # Mock _build_command to add our test script as first argument
messages = [] original_build_command = SubprocessCLITransport._build_command
async for msg in query(prompt=message_stream()):
messages.append(msg)
# Should get the result message def mock_build_command(self):
assert len(messages) == 1 # Get original command
assert isinstance(messages[0], ResultMessage) cmd = original_build_command(self)
assert messages[0].subtype == "success" # Replace the CLI path with python + script
cmd[0] = test_script
return cmd
with patch.object(
SubprocessCLITransport,
"_build_command",
mock_build_command
):
# Run query with async iterable
messages = []
async for msg in query(prompt=message_stream()):
messages.append(msg)
# Should get the result message
assert len(messages) == 1
assert isinstance(messages[0], ResultMessage)
assert messages[0].subtype == "success"
finally: finally:
# Clean up # Clean up
Path(test_script).unlink() Path(test_script).unlink()