mirror of
https://github.com/anthropics/claude-agent-sdk-python.git
synced 2025-10-06 01:00:03 +03:00
Fix test
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -373,80 +372,6 @@ class TestClaudeSDKClientStreaming:
|
||||
class TestQueryWithAsyncIterable:
|
||||
"""Test query() function with async iterable inputs."""
|
||||
|
||||
def _create_test_script(
|
||||
self, expected_messages=None, response=None, should_error=False
|
||||
):
|
||||
"""Create a test script that validates CLI args and stdin messages.
|
||||
|
||||
Args:
|
||||
expected_messages: List of expected message content strings, or None to skip validation
|
||||
response: Custom response to output, defaults to a success result
|
||||
should_error: If True, script will exit with error after reading stdin
|
||||
|
||||
Returns:
|
||||
Path to the test script
|
||||
"""
|
||||
if response is None:
|
||||
response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
|
||||
|
||||
script_content = textwrap.dedent(
|
||||
"""
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
|
||||
# Check command line args
|
||||
args = sys.argv[1:]
|
||||
assert "--output-format" in args
|
||||
assert "stream-json" in args
|
||||
|
||||
# Read stdin messages
|
||||
stdin_messages = []
|
||||
stdin_closed = False
|
||||
try:
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
stdin_closed = True
|
||||
break
|
||||
stdin_messages.append(line.strip())
|
||||
except:
|
||||
stdin_closed = True
|
||||
""",
|
||||
)
|
||||
|
||||
if expected_messages is not None:
|
||||
script_content += textwrap.dedent(
|
||||
f"""
|
||||
# Verify we got the expected messages
|
||||
assert len(stdin_messages) == {len(expected_messages)}
|
||||
""",
|
||||
)
|
||||
for i, msg in enumerate(expected_messages):
|
||||
script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n'''
|
||||
|
||||
if should_error:
|
||||
script_content += textwrap.dedent(
|
||||
"""
|
||||
sys.exit(1)
|
||||
""",
|
||||
)
|
||||
else:
|
||||
script_content += textwrap.dedent(
|
||||
f"""
|
||||
# Output response
|
||||
print('{response}')
|
||||
""",
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||
test_script = f.name
|
||||
f.write(script_content)
|
||||
|
||||
Path(test_script).chmod(0o755)
|
||||
return test_script
|
||||
|
||||
def test_query_with_async_iterable(self):
|
||||
"""Test query with async iterable of messages."""
|
||||
|
||||
@@ -455,32 +380,63 @@ class TestQueryWithAsyncIterable:
|
||||
yield {"type": "user", "message": {"role": "user", "content": "First"}}
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
|
||||
|
||||
test_script = self._create_test_script(
|
||||
expected_messages=["First", "Second"]
|
||||
)
|
||||
# Create a simple test script that validates stdin and outputs a result
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||
test_script = f.name
|
||||
f.write("""#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Read stdin messages
|
||||
stdin_messages = []
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
stdin_messages.append(line.strip())
|
||||
|
||||
# Verify we got 2 messages
|
||||
assert len(stdin_messages) == 2
|
||||
assert '"First"' in stdin_messages[0]
|
||||
assert '"Second"' in stdin_messages[1]
|
||||
|
||||
# Output a valid result
|
||||
print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
|
||||
""")
|
||||
|
||||
Path(test_script).chmod(0o755)
|
||||
|
||||
try:
|
||||
# Mock _build_command to return our test script
|
||||
# Mock _find_cli to return python executing our test script
|
||||
with patch.object(
|
||||
SubprocessCLITransport,
|
||||
"_build_command",
|
||||
return_value=[
|
||||
sys.executable,
|
||||
test_script,
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--verbose",
|
||||
],
|
||||
"_find_cli",
|
||||
return_value=sys.executable
|
||||
):
|
||||
# Run query with async iterable
|
||||
messages = []
|
||||
async for msg in query(prompt=message_stream()):
|
||||
messages.append(msg)
|
||||
# Mock _build_command to add our test script as first argument
|
||||
original_build_command = SubprocessCLITransport._build_command
|
||||
|
||||
# Should get the result message
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], ResultMessage)
|
||||
assert messages[0].subtype == "success"
|
||||
def mock_build_command(self):
|
||||
# Get original command
|
||||
cmd = original_build_command(self)
|
||||
# Replace the CLI path with python + script
|
||||
cmd[0] = test_script
|
||||
return cmd
|
||||
|
||||
with patch.object(
|
||||
SubprocessCLITransport,
|
||||
"_build_command",
|
||||
mock_build_command
|
||||
):
|
||||
# Run query with async iterable
|
||||
messages = []
|
||||
async for msg in query(prompt=message_stream()):
|
||||
messages.append(msg)
|
||||
|
||||
# Should get the result message
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], ResultMessage)
|
||||
assert messages[0].subtype == "success"
|
||||
finally:
|
||||
# Clean up
|
||||
Path(test_script).unlink()
|
||||
|
||||
Reference in New Issue
Block a user