Merge pull request #47 from AndreCNF/fix/llama_index_callback_agents

fix: 🐛 Llama Index imports and track costs and token counts in the class
This commit is contained in:
Alex Reibman
2024-06-03 22:44:57 -07:00
committed by GitHub
7 changed files with 122 additions and 62 deletions

View File

@@ -506,7 +506,7 @@ pip install `'tokencost[llama-index]'`
To use the base callback handler, you may import it:
```python
from tokencost.callbacks.llama_index import BaseCallbackHandler
from tokencost.callbacks.llama_index import TokenCostHandler
```
and pass to your framework callback handler.

View File

@@ -35,7 +35,7 @@ dev = [
"coverage[toml]>=7.4.0",
]
llama-index = [
"llama-index>=0.9.24"
"llama-index>=0.10.23"
]
[project.urls]

View File

@@ -131,21 +131,21 @@ def test_count_string_invalid_model():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(MESSAGES, "gpt-3.5-turbo", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal('0.0000255')),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal('0.000015')),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal('0.0000225')),
(MESSAGES, "gpt-4", Decimal('0.00045')),
(MESSAGES, "gpt-4-0314", Decimal('0.00045')),
(MESSAGES, "gpt-4-32k", Decimal('0.00090')),
(MESSAGES, "gpt-4-32k-0314", Decimal('0.00090')),
(MESSAGES, "gpt-4-0613", Decimal('0.00045')),
(MESSAGES, "gpt-4-1106-preview", Decimal('0.00015')),
(MESSAGES, "gpt-4-vision-preview", Decimal('0.00015')),
(STRING, "text-embedding-ada-002", Decimal('0.0000004')),
(MESSAGES, "gpt-3.5-turbo", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal("0.0000255")),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal("0.000015")),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal("0.0000225")),
(MESSAGES, "gpt-4", Decimal("0.00045")),
(MESSAGES, "gpt-4-0314", Decimal("0.00045")),
(MESSAGES, "gpt-4-32k", Decimal("0.00090")),
(MESSAGES, "gpt-4-32k-0314", Decimal("0.00090")),
(MESSAGES, "gpt-4-0613", Decimal("0.00045")),
(MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
(MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
(STRING, "text-embedding-ada-002", Decimal("0.0000004")),
],
)
def test_calculate_prompt_cost(prompt, model, expected_output):
@@ -165,20 +165,20 @@ def test_invalid_prompt_format():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(STRING, "gpt-3.5-turbo", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0301", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0613", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-16k", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-1106", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-instruct", Decimal('0.000008')),
(STRING, "gpt-4", Decimal('0.00024')),
(STRING, "gpt-4-0314", Decimal('0.00024')),
(STRING, "gpt-4-32k", Decimal('0.00048')),
(STRING, "gpt-4-32k-0314", Decimal('0.00048')),
(STRING, "gpt-4-0613", Decimal('0.00024')),
(STRING, "gpt-4-1106-preview", Decimal('0.00012')),
(STRING, "gpt-4-vision-preview", Decimal('0.00012')),
(STRING, "gpt-3.5-turbo", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0301", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0613", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-16k", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-1106", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-instruct", Decimal("0.000008")),
(STRING, "gpt-4", Decimal("0.00024")),
(STRING, "gpt-4-0314", Decimal("0.00024")),
(STRING, "gpt-4-32k", Decimal("0.00048")),
(STRING, "gpt-4-32k-0314", Decimal("0.00048")),
(STRING, "gpt-4-0613", Decimal("0.00024")),
(STRING, "gpt-4-1106-preview", Decimal("0.00012")),
(STRING, "gpt-4-vision-preview", Decimal("0.00012")),
(STRING, "text-embedding-ada-002", 0),
],
)
@@ -213,9 +213,9 @@ def test_calculate_invalid_input_types():
@pytest.mark.parametrize(
"num_tokens,model,token_type,expected_output",
[
(10, "gpt-3.5-turbo", 'input', Decimal('0.0000150')), # Example values
(5, "gpt-4", 'output', Decimal('0.00030')), # Example values
(10, "ai21.j2-mid-v1", 'input', Decimal('0.0001250')), # Example values
(10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values
(5, "gpt-4", "output", Decimal("0.00030")), # Example values
(10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values
],
)
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):

View File

@@ -1,7 +1,7 @@
# test_llama_index.py
import pytest
from tokencost.callbacks import llama_index
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from unittest.mock import MagicMock
# Mock the calculate_prompt_cost and calculate_completion_cost functions
@@ -20,36 +20,39 @@ def mock_chat_message(monkeypatch):
def __str__(self):
return self.text
monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
monkeypatch.setattr("llama_index.core.llms.ChatMessage", MockChatMessage)
return MockChatMessage
# Test the _calc_llm_event_cost method for prompt and completion
def test_calc_llm_event_cost_prompt_completion(capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
payload = {
EventPayload.PROMPT: STRING,
EventPayload.COMPLETION: STRING
}
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {EventPayload.PROMPT: STRING, EventPayload.COMPLETION: STRING}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000060" in captured.out
assert "# Completion: 0.000008" in captured.out
# Test the _calc_llm_event_cost method for messages and response
def test_calc_llm_event_cost_messages_response(mock_chat_message, capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {
EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
EventPayload.RESPONSE: "test response"
EventPayload.MESSAGES: [
mock_chat_message("message 1"),
mock_chat_message("message 2"),
],
EventPayload.RESPONSE: "test response",
}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000105" in captured.out
assert "# Completion: 0.000004" in captured.out
# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
# depending on the specific logic and requirements of those methods.

View File

@@ -3,5 +3,6 @@ from .costs import (
count_string_tokens,
calculate_completion_cost,
calculate_prompt_cost,
calculate_all_costs_and_tokens,
)
from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict, List, Optional, cast
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from tokencost import calculate_prompt_cost, calculate_completion_cost
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.llms import ChatMessage
from tokencost import calculate_all_costs_and_tokens
class TokenCostHandler(BaseCallbackHandler):
@@ -10,6 +11,10 @@ class TokenCostHandler(BaseCallbackHandler):
def __init__(self, model) -> None:
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
self.model = model
self.prompt_cost = 0
self.completion_cost = 0
self.prompt_tokens = 0
self.completion_tokens = 0
def start_trace(self, trace_id: Optional[str] = None) -> None:
return
@@ -22,27 +27,34 @@ class TokenCostHandler(BaseCallbackHandler):
return
def _calc_llm_event_cost(self, payload: dict) -> None:
from llama_index.llms import ChatMessage
prompt_cost = 0
completion_cost = 0
if EventPayload.PROMPT in payload:
prompt = str(payload.get(EventPayload.PROMPT))
completion = str(payload.get(EventPayload.COMPLETION))
prompt_cost = calculate_prompt_cost(prompt, self.model)
completion_cost = calculate_completion_cost(completion, self.model)
estimates = calculate_all_costs_and_tokens(prompt, completion, self.model)
elif EventPayload.MESSAGES in payload:
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
messages_str = "\n".join([str(x) for x in messages])
prompt_cost = calculate_prompt_cost(messages_str, self.model)
response = str(payload.get(EventPayload.RESPONSE))
completion_cost = calculate_completion_cost(response, self.model)
estimates = calculate_all_costs_and_tokens(
messages_str, response, self.model
)
print(f"# Prompt cost: {prompt_cost}")
print(f"# Completion: {completion_cost}")
self.prompt_cost += estimates["prompt_cost"]
self.completion_cost += estimates["completion_cost"]
self.prompt_tokens += estimates["prompt_tokens"]
self.completion_tokens += estimates["completion_tokens"]
print(f"# Prompt cost: {estimates['prompt_cost']}")
print(f"# Completion: {estimates['completion_cost']}")
print("\n")
def reset_counts(self) -> None:
self.prompt_cost = 0
self.completion_cost = 0
self.prompt_tokens = 0
self.completion_tokens = 0
def on_event_start(
self,
event_type: CBEventType,

View File

@@ -1,6 +1,7 @@
"""
Costs dictionary and utility tool for counting tokens
"""
import tiktoken
from typing import Union, List, Dict
from .constants import TOKEN_COSTS
@@ -57,10 +58,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
tokens_per_message = 4
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
logging.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
logging.warning(
"gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
)
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logging.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
logging.warning(
"gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
)
return count_message_tokens(messages, model="gpt-4-0613")
else:
raise KeyError(
@@ -118,7 +123,9 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> De
Double-check your spelling, or submit an issue/PR"""
)
cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
cost_per_token_key = (
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
)
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
return Decimal(str(cost_per_token)) * Decimal(num_tokens)
@@ -164,7 +171,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
else count_message_tokens(prompt, model)
)
return calculate_cost_by_tokens(prompt_tokens, model, 'input')
return calculate_cost_by_tokens(prompt_tokens, model, "input")
def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -191,4 +198,41 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
)
completion_tokens = count_string_tokens(completion, model)
return calculate_cost_by_tokens(completion_tokens, model, 'output')
return calculate_cost_by_tokens(completion_tokens, model, "output")
def calculate_all_costs_and_tokens(
prompt: Union[List[dict], str], completion: str, model: str
) -> dict:
"""
Calculate the prompt and completion costs and tokens in USD.
Args:
prompt (Union[List[dict], str]): List of message objects or single string prompt.
completion (str): Completion string.
model (str): The model name.
Returns:
dict: The calculated cost and tokens in USD.
e.g.:
>>> prompt = "Hello world"
>>> completion = "How may I assist you today?"
>>> calculate_all_costs_and_tokens(prompt, completion, "gpt-3.5-turbo")
{'prompt_cost': Decimal('0.0000030'), 'prompt_tokens': 2, 'completion_cost': Decimal('0.000014'), 'completion_tokens': 7}
"""
prompt_cost = calculate_prompt_cost(prompt, model)
completion_cost = calculate_completion_cost(completion, model)
prompt_tokens = (
count_string_tokens(prompt, model)
if isinstance(prompt, str)
else count_message_tokens(prompt, model)
)
completion_tokens = count_string_tokens(completion, model)
return {
"prompt_cost": prompt_cost,
"prompt_tokens": prompt_tokens,
"completion_cost": completion_cost,
"completion_tokens": completion_tokens,
}