Mirror of https://github.com/AgentOps-AI/tokencost.git, synced 2024-06-22 04:30:40 +03:00
Merge pull request #47 from AndreCNF/fix/llama_index_callback_agents
fix: 🐛 Llama Index imports and track costs and token counts in the class
@@ -506,7 +506,7 @@ pip install `'tokencost[llama-index]'`
 To use the base callback handler, you may import it:

 ```python
-from tokencost.callbacks.llama_index import BaseCallbackHandler
+from tokencost.callbacks.llama_index import TokenCostHandler
 ```

 and pass to your framework callback handler.

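For orientation (not part of this diff), a minimal sketch of passing the renamed handler to LlamaIndex's callback machinery, assuming the post-0.10 `CallbackManager`/`Settings` entry points:

```python
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager
from tokencost.callbacks.llama_index import TokenCostHandler

# Register the handler globally so every LLM event gets costed
token_cost_handler = TokenCostHandler(model="gpt-3.5-turbo")
Settings.callback_manager = CallbackManager([token_cost_handler])
```
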
@@ -35,7 +35,7 @@ dev = [
     "coverage[toml]>=7.4.0",
 ]
 llama-index = [
-    "llama-index>=0.9.24"
+    "llama-index>=0.10.23"
 ]

 [project.urls]

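The pin moves from 0.9.24 to 0.10.23 because the `llama_index.core` namespace used by the updated imports was introduced in the llama-index 0.10 series; an older resolution of the extra would lack those modules.
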
@@ -131,21 +131,21 @@ def test_count_string_invalid_model():
 @pytest.mark.parametrize(
     "prompt,model,expected_output",
     [
-        (MESSAGES, "gpt-3.5-turbo", Decimal('0.0000225')),
-        (MESSAGES, "gpt-3.5-turbo-0301", Decimal('0.0000255')),
-        (MESSAGES, "gpt-3.5-turbo-0613", Decimal('0.0000225')),
-        (MESSAGES, "gpt-3.5-turbo-16k", Decimal('0.000045')),
-        (MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal('0.000045')),
-        (MESSAGES, "gpt-3.5-turbo-1106", Decimal('0.000015')),
-        (MESSAGES, "gpt-3.5-turbo-instruct", Decimal('0.0000225')),
-        (MESSAGES, "gpt-4", Decimal('0.00045')),
-        (MESSAGES, "gpt-4-0314", Decimal('0.00045')),
-        (MESSAGES, "gpt-4-32k", Decimal('0.00090')),
-        (MESSAGES, "gpt-4-32k-0314", Decimal('0.00090')),
-        (MESSAGES, "gpt-4-0613", Decimal('0.00045')),
-        (MESSAGES, "gpt-4-1106-preview", Decimal('0.00015')),
-        (MESSAGES, "gpt-4-vision-preview", Decimal('0.00015')),
-        (STRING, "text-embedding-ada-002", Decimal('0.0000004')),
+        (MESSAGES, "gpt-3.5-turbo", Decimal("0.0000225")),
+        (MESSAGES, "gpt-3.5-turbo-0301", Decimal("0.0000255")),
+        (MESSAGES, "gpt-3.5-turbo-0613", Decimal("0.0000225")),
+        (MESSAGES, "gpt-3.5-turbo-16k", Decimal("0.000045")),
+        (MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal("0.000045")),
+        (MESSAGES, "gpt-3.5-turbo-1106", Decimal("0.000015")),
+        (MESSAGES, "gpt-3.5-turbo-instruct", Decimal("0.0000225")),
+        (MESSAGES, "gpt-4", Decimal("0.00045")),
+        (MESSAGES, "gpt-4-0314", Decimal("0.00045")),
+        (MESSAGES, "gpt-4-32k", Decimal("0.00090")),
+        (MESSAGES, "gpt-4-32k-0314", Decimal("0.00090")),
+        (MESSAGES, "gpt-4-0613", Decimal("0.00045")),
+        (MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
+        (MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
+        (STRING, "text-embedding-ada-002", Decimal("0.0000004")),
     ],
 )
 def test_calculate_prompt_cost(prompt, model, expected_output):

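A quick cross-check of these expected values (my arithmetic, using the per-token prices these tests assume): the gpt-3.5-turbo expectation of Decimal('0.0000225') implies the MESSAGES fixture counts to 15 prompt tokens, and the gpt-4 rows agree.

```python
from decimal import Decimal

# Sanity check (assumed prices): 15 prompt tokens at gpt-3.5-turbo's
# $0.0000015/input token, and the same 15 at gpt-4's $0.00003/input token
assert 15 * Decimal("0.0000015") == Decimal("0.0000225")
assert 15 * Decimal("0.00003") == Decimal("0.00045")
```
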
@@ -165,20 +165,20 @@ def test_invalid_prompt_format():
 @pytest.mark.parametrize(
     "prompt,model,expected_output",
     [
-        (STRING, "gpt-3.5-turbo", Decimal('0.000008')),
-        (STRING, "gpt-3.5-turbo-0301", Decimal('0.000008')),
-        (STRING, "gpt-3.5-turbo-0613", Decimal('0.000008')),
-        (STRING, "gpt-3.5-turbo-16k", Decimal('0.000016')),
-        (STRING, "gpt-3.5-turbo-16k-0613", Decimal('0.000016')),
-        (STRING, "gpt-3.5-turbo-1106", Decimal('0.000008')),
-        (STRING, "gpt-3.5-turbo-instruct", Decimal('0.000008')),
-        (STRING, "gpt-4", Decimal('0.00024')),
-        (STRING, "gpt-4-0314", Decimal('0.00024')),
-        (STRING, "gpt-4-32k", Decimal('0.00048')),
-        (STRING, "gpt-4-32k-0314", Decimal('0.00048')),
-        (STRING, "gpt-4-0613", Decimal('0.00024')),
-        (STRING, "gpt-4-1106-preview", Decimal('0.00012')),
-        (STRING, "gpt-4-vision-preview", Decimal('0.00012')),
+        (STRING, "gpt-3.5-turbo", Decimal("0.000008")),
+        (STRING, "gpt-3.5-turbo-0301", Decimal("0.000008")),
+        (STRING, "gpt-3.5-turbo-0613", Decimal("0.000008")),
+        (STRING, "gpt-3.5-turbo-16k", Decimal("0.000016")),
+        (STRING, "gpt-3.5-turbo-16k-0613", Decimal("0.000016")),
+        (STRING, "gpt-3.5-turbo-1106", Decimal("0.000008")),
+        (STRING, "gpt-3.5-turbo-instruct", Decimal("0.000008")),
+        (STRING, "gpt-4", Decimal("0.00024")),
+        (STRING, "gpt-4-0314", Decimal("0.00024")),
+        (STRING, "gpt-4-32k", Decimal("0.00048")),
+        (STRING, "gpt-4-32k-0314", Decimal("0.00048")),
+        (STRING, "gpt-4-0613", Decimal("0.00024")),
+        (STRING, "gpt-4-1106-preview", Decimal("0.00012")),
+        (STRING, "gpt-4-vision-preview", Decimal("0.00012")),
         (STRING, "text-embedding-ada-002", 0),
     ],
 )

@@ -213,9 +213,9 @@ def test_calculate_invalid_input_types():
 @pytest.mark.parametrize(
     "num_tokens,model,token_type,expected_output",
     [
-        (10, "gpt-3.5-turbo", 'input', Decimal('0.0000150')),  # Example values
-        (5, "gpt-4", 'output', Decimal('0.00030')),  # Example values
-        (10, "ai21.j2-mid-v1", 'input', Decimal('0.0001250')),  # Example values
+        (10, "gpt-3.5-turbo", "input", Decimal("0.0000150")),  # Example values
+        (5, "gpt-4", "output", Decimal("0.00030")),  # Example values
+        (10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")),  # Example values
     ],
 )
 def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):

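Unlike the prompt-cost tests, `calculate_cost_by_tokens` skips tokenization entirely: the `token_type` switch only selects which per-token price to multiply by. Cross-checking the gpt-4 row (my arithmetic, assuming its $0.00006/output-token price):

```python
from decimal import Decimal

# Direct per-token multiply, mirroring what calculate_cost_by_tokens does
# internally (the price is an assumption matching the test row above)
assert 5 * Decimal("0.00006") == Decimal("0.00030")
```
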
@@ -1,7 +1,7 @@
 # test_llama_index.py
 import pytest
 from tokencost.callbacks import llama_index
-from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.core.callbacks.schema import CBEventType, EventPayload
 from unittest.mock import MagicMock

 # Mock the calculate_prompt_cost and calculate_completion_cost functions

@@ -20,36 +20,39 @@ def mock_chat_message(monkeypatch):
         def __str__(self):
             return self.text

-    monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
+    monkeypatch.setattr("llama_index.core.llms.ChatMessage", MockChatMessage)
     return MockChatMessage


 # Test the _calc_llm_event_cost method for prompt and completion


 def test_calc_llm_event_cost_prompt_completion(capsys):
-    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
-    payload = {
-        EventPayload.PROMPT: STRING,
-        EventPayload.COMPLETION: STRING
-    }
+    handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
+    payload = {EventPayload.PROMPT: STRING, EventPayload.COMPLETION: STRING}
     handler._calc_llm_event_cost(payload)
     captured = capsys.readouterr()
     assert "# Prompt cost: 0.0000060" in captured.out
     assert "# Completion: 0.000008" in captured.out


 # Test the _calc_llm_event_cost method for messages and response


 def test_calc_llm_event_cost_messages_response(mock_chat_message, capsys):
-    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
+    handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
     payload = {
-        EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
-        EventPayload.RESPONSE: "test response"
+        EventPayload.MESSAGES: [
+            mock_chat_message("message 1"),
+            mock_chat_message("message 2"),
+        ],
+        EventPayload.RESPONSE: "test response",
     }
     handler._calc_llm_event_cost(payload)
     captured = capsys.readouterr()
     assert "# Prompt cost: 0.0000105" in captured.out
     assert "# Completion: 0.000004" in captured.out


 # Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
 # depending on the specific logic and requirements of those methods.

@@ -3,5 +3,6 @@ from .costs import (
     count_string_tokens,
     calculate_completion_cost,
     calculate_prompt_cost,
+    calculate_all_costs_and_tokens,
 )
 from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs

@@ -1,7 +1,8 @@
 from typing import Any, Dict, List, Optional, cast
-from llama_index.callbacks.base_handler import BaseCallbackHandler
-from llama_index.callbacks.schema import CBEventType, EventPayload
-from tokencost import calculate_prompt_cost, calculate_completion_cost
+from llama_index.core.callbacks.base_handler import BaseCallbackHandler
+from llama_index.core.callbacks.schema import CBEventType, EventPayload
+from llama_index.core.llms import ChatMessage
+from tokencost import calculate_all_costs_and_tokens


 class TokenCostHandler(BaseCallbackHandler):

@@ -10,6 +11,10 @@ class TokenCostHandler(BaseCallbackHandler):
     def __init__(self, model) -> None:
         super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
         self.model = model
+        self.prompt_cost = 0
+        self.completion_cost = 0
+        self.prompt_tokens = 0
+        self.completion_tokens = 0

     def start_trace(self, trace_id: Optional[str] = None) -> None:
         return

@@ -22,27 +27,34 @@ class TokenCostHandler(BaseCallbackHandler):
         return

     def _calc_llm_event_cost(self, payload: dict) -> None:
-        from llama_index.llms import ChatMessage
-
-        prompt_cost = 0
-        completion_cost = 0
         if EventPayload.PROMPT in payload:
             prompt = str(payload.get(EventPayload.PROMPT))
             completion = str(payload.get(EventPayload.COMPLETION))
-            prompt_cost = calculate_prompt_cost(prompt, self.model)
-            completion_cost = calculate_completion_cost(completion, self.model)
+            estimates = calculate_all_costs_and_tokens(prompt, completion, self.model)

         elif EventPayload.MESSAGES in payload:
             messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
             messages_str = "\n".join([str(x) for x in messages])
-            prompt_cost = calculate_prompt_cost(messages_str, self.model)
             response = str(payload.get(EventPayload.RESPONSE))
-            completion_cost = calculate_completion_cost(response, self.model)
+            estimates = calculate_all_costs_and_tokens(
+                messages_str, response, self.model
+            )

-        print(f"# Prompt cost: {prompt_cost}")
-        print(f"# Completion: {completion_cost}")
+        self.prompt_cost += estimates["prompt_cost"]
+        self.completion_cost += estimates["completion_cost"]
+        self.prompt_tokens += estimates["prompt_tokens"]
+        self.completion_tokens += estimates["completion_tokens"]
+
+        print(f"# Prompt cost: {estimates['prompt_cost']}")
+        print(f"# Completion: {estimates['completion_cost']}")
         print("\n")

+    def reset_counts(self) -> None:
+        self.prompt_cost = 0
+        self.completion_cost = 0
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+
     def on_event_start(
         self,
         event_type: CBEventType,

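With the accumulators in place, callers can read running totals off the handler between queries. A hypothetical sketch (the attribute and method names come from this diff; the query pipeline itself is illustrative, continuing the wiring shown earlier):

```python
from tokencost.callbacks.llama_index import TokenCostHandler

token_cost_handler = TokenCostHandler(model="gpt-3.5-turbo")
# ... run queries through a LlamaIndex pipeline that has this handler
# registered via its CallbackManager ...

# Running totals accumulate across every LLM event the handler sees
total_cost = token_cost_handler.prompt_cost + token_cost_handler.completion_cost
total_tokens = token_cost_handler.prompt_tokens + token_cost_handler.completion_tokens
print(f"spent ${total_cost} across {total_tokens} tokens")

token_cost_handler.reset_counts()  # zero the counters for the next window
```
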
@@ -1,6 +1,7 @@
 """
 Costs dictionary and utility tool for counting tokens
 """
+
 import tiktoken
 from typing import Union, List, Dict
 from .constants import TOKEN_COSTS

@@ -57,10 +58,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
         tokens_per_message = 4
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif "gpt-3.5-turbo" in model:
-        logging.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        logging.warning(
+            "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
+        )
         return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
     elif "gpt-4" in model:
-        logging.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        logging.warning(
+            "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
+        )
         return count_message_tokens(messages, model="gpt-4-0613")
     else:
         raise KeyError(

@@ -118,7 +123,9 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
             Double-check your spelling, or submit an issue/PR"""
         )

-    cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
+    cost_per_token_key = (
+        "input_cost_per_token" if token_type == "input" else "output_cost_per_token"
+    )
     cost_per_token = TOKEN_COSTS[model][cost_per_token_key]

     return Decimal(str(cost_per_token)) * Decimal(num_tokens)

@@ -164,7 +171,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal:
         else count_message_tokens(prompt, model)
     )

-    return calculate_cost_by_tokens(prompt_tokens, model, 'input')
+    return calculate_cost_by_tokens(prompt_tokens, model, "input")


 def calculate_completion_cost(completion: str, model: str) -> Decimal:

@@ -191,4 +198,41 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
         )
     completion_tokens = count_string_tokens(completion, model)

-    return calculate_cost_by_tokens(completion_tokens, model, 'output')
+    return calculate_cost_by_tokens(completion_tokens, model, "output")
+
+
+def calculate_all_costs_and_tokens(
+    prompt: Union[List[dict], str], completion: str, model: str
+) -> dict:
+    """
+    Calculate the prompt and completion costs and tokens in USD.
+
+    Args:
+        prompt (Union[List[dict], str]): List of message objects or single string prompt.
+        completion (str): Completion string.
+        model (str): The model name.
+
+    Returns:
+        dict: The calculated cost and tokens in USD.
+
+    e.g.:
+    >>> prompt = "Hello world"
+    >>> completion = "How may I assist you today?"
+    >>> calculate_all_costs_and_tokens(prompt, completion, "gpt-3.5-turbo")
+    {'prompt_cost': Decimal('0.0000030'), 'prompt_tokens': 2, 'completion_cost': Decimal('0.000014'), 'completion_tokens': 7}
+    """
+    prompt_cost = calculate_prompt_cost(prompt, model)
+    completion_cost = calculate_completion_cost(completion, model)
+    prompt_tokens = (
+        count_string_tokens(prompt, model)
+        if isinstance(prompt, str)
+        else count_message_tokens(prompt, model)
+    )
+    completion_tokens = count_string_tokens(completion, model)
+
+    return {
+        "prompt_cost": prompt_cost,
+        "prompt_tokens": prompt_tokens,
+        "completion_cost": completion_cost,
+        "completion_tokens": completion_tokens,
+    }

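The docstring covers the string-prompt case; the new helper also accepts an OpenAI-format message list per its `Union[List[dict], str]` signature, which is what the callback handler's MESSAGES path relies on. A short sketch (my example; the message content is illustrative):

```python
from tokencost import calculate_all_costs_and_tokens

# Message-style prompt: any list of {"role", "content"} dicts works
messages = [{"role": "user", "content": "Hello world"}]
estimates = calculate_all_costs_and_tokens(messages, "Hi there!", "gpt-3.5-turbo")

# One call returns both costs and both token counts
print(estimates["prompt_cost"], estimates["prompt_tokens"])
print(estimates["completion_cost"], estimates["completion_tokens"])
```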