Mirror of https://github.com/simonw/llm-claude-3.git (synced 2025-01-23 19:28:29 +03:00).
Commit: "Use response.set_usage(), closes #29".
Refs: https://github.com/simonw/llm/issues/610
This commit is contained in:
@@ -231,6 +231,13 @@ class _Shared:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
return kwargs
|
||||
|
||||
def set_usage(self, response):
    """Record token usage on an llm Response object.

    Removes the "usage" dict from ``response.response_json`` (so the token
    counts are not stored twice) and forwards the input/output token counts
    to ``response.set_usage()``.

    :param response: an llm Response whose ``response_json`` was populated
        from ``completion.model_dump()``.
    """
    # Use a default of None so an API response without a "usage" key is
    # tolerated instead of raising KeyError; the guard below then skips it.
    usage = response.response_json.pop("usage", None)
    if usage:
        response.set_usage(
            input=usage.get("input_tokens"), output=usage.get("output_tokens")
        )
|
||||
|
||||
def __str__(self):
|
||||
return "Anthropic Messages: {}".format(self.model_id)
|
||||
|
||||
@@ -250,6 +257,7 @@ class ClaudeMessages(_Shared, llm.Model):
|
||||
completion = client.messages.create(**kwargs)
|
||||
yield completion.content[0].text
|
||||
response.response_json = completion.model_dump()
|
||||
self.set_usage(response)
|
||||
|
||||
|
||||
class ClaudeMessagesLong(ClaudeMessages):
|
||||
@@ -270,6 +278,7 @@ class AsyncClaudeMessages(_Shared, llm.AsyncModel):
|
||||
completion = await client.messages.create(**kwargs)
|
||||
yield completion.content[0].text
|
||||
response.response_json = completion.model_dump()
|
||||
self.set_usage(response)
|
||||
|
||||
|
||||
class AsyncClaudeMessagesLong(AsyncClaudeMessages):
|
||||
|
||||
@@ -9,7 +9,7 @@ classifiers = [
|
||||
"License :: OSI Approved :: Apache Software License"
|
||||
]
|
||||
dependencies = [
|
||||
"llm>=0.18",
|
||||
"llm>=0.19a0",
|
||||
"anthropic>=0.39.0",
|
||||
]
|
||||
|
||||
|
||||
@@ -30,8 +30,10 @@ def test_prompt():
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {"input_tokens": 17, "output_tokens": 15},
|
||||
}
|
||||
assert response.input_tokens == 17
|
||||
assert response.output_tokens == 15
|
||||
assert response.token_details is None
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
@@ -50,8 +52,10 @@ async def test_async_prompt():
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {"input_tokens": 17, "output_tokens": 15},
|
||||
}
|
||||
assert response.input_tokens == 17
|
||||
assert response.output_tokens == 15
|
||||
assert response.token_details is None
|
||||
|
||||
|
||||
EXPECTED_IMAGE_TEXT = (
|
||||
@@ -86,5 +90,7 @@ def test_image_prompt():
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {"input_tokens": 76, "output_tokens": 75},
|
||||
}
|
||||
assert response.input_tokens == 76
|
||||
assert response.output_tokens == 75
|
||||
assert response.token_details is None
|
||||
|
||||
Reference in New Issue
Block a user