From fd898ff2b52902889085e78eaf8381961e4620ff Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Tue, 19 Nov 2024 20:28:28 -0800
Subject: [PATCH] Use response.set_usage(), closes #29

Refs https://github.com/simonw/llm/issues/610
---
 llm_claude_3.py        |  9 +++++++++
 pyproject.toml         |  2 +-
 tests/test_claude_3.py | 12 +++++++++---
 3 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/llm_claude_3.py b/llm_claude_3.py
index a05b01b..0a6e236 100644
--- a/llm_claude_3.py
+++ b/llm_claude_3.py
@@ -231,6 +231,13 @@ class _Shared:
             kwargs["extra_headers"] = self.extra_headers
         return kwargs
 
+    def set_usage(self, response):
+        usage = response.response_json.pop("usage")
+        if usage:
+            response.set_usage(
+                input=usage.get("input_tokens"), output=usage.get("output_tokens")
+            )
+
     def __str__(self):
         return "Anthropic Messages: {}".format(self.model_id)
 
@@ -250,6 +257,7 @@ class ClaudeMessages(_Shared, llm.Model):
             completion = client.messages.create(**kwargs)
             yield completion.content[0].text
             response.response_json = completion.model_dump()
+        self.set_usage(response)
 
 
 class ClaudeMessagesLong(ClaudeMessages):
@@ -270,6 +278,7 @@ class AsyncClaudeMessages(_Shared, llm.AsyncModel):
             completion = await client.messages.create(**kwargs)
             yield completion.content[0].text
             response.response_json = completion.model_dump()
+        self.set_usage(response)
 
 
 class AsyncClaudeMessagesLong(AsyncClaudeMessages):
diff --git a/pyproject.toml b/pyproject.toml
index c9ca586..0bfe84e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "llm>=0.18",
+    "llm>=0.19a0",
     "anthropic>=0.39.0",
 ]
 
diff --git a/tests/test_claude_3.py b/tests/test_claude_3.py
index 8c158ca..258a006 100644
--- a/tests/test_claude_3.py
+++ b/tests/test_claude_3.py
@@ -30,8 +30,10 @@ def test_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 17, "output_tokens": 15},
     }
+    assert response.input_tokens == 17
+    assert response.output_tokens == 15
+    assert response.token_details is None
 
 
 @pytest.mark.vcr
@@ -50,8 +52,10 @@ async def test_async_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 17, "output_tokens": 15},
     }
+    assert response.input_tokens == 17
+    assert response.output_tokens == 15
+    assert response.token_details is None
 
 
 EXPECTED_IMAGE_TEXT = (
@@ -86,5 +90,7 @@ def test_image_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 76, "output_tokens": 75},
     }
+    assert response.input_tokens == 76
+    assert response.output_tokens == 75
+    assert response.token_details is None
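-- 
A minimal usage sketch of what this change surfaces, kept below the
signature marker so it is not part of the applied patch. It assumes
llm>=0.19a0 and an Anthropic API key are configured; the model alias
and prompt are illustrative, and the three attributes printed are the
ones asserted in the tests above:

    import llm

    model = llm.get_model("claude-3-opus")
    response = model.prompt("Two names for a pet pelican, be brief")
    print(response.text())
    # Populated by the set_usage() call added in this patch:
    print(response.input_tokens)   # e.g. 17
    print(response.output_tokens)  # e.g. 15
    print(response.token_details)  # None: this plugin does not set it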