feat: Track costs and token counts in the class, across multiple runs (e.g. in agents)

This commit is contained in:
André Ferreira
2024-05-27 18:45:43 +02:00
parent c5ad62d8a0
commit 400e3c428d
2 changed files with 44 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, cast
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.llms import ChatMessage
from tokencost import calculate_prompt_cost, calculate_completion_cost
@@ -10,6 +11,10 @@ class TokenCostHandler(BaseCallbackHandler):
def __init__(self, model) -> None:
    """Create a cost-tracking handler for *model*.

    All cost/token accumulators start at zero and grow across events
    until `reset_counts()` is called.
    """
    super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
    self.model = model
    # Running totals (chained assignment keeps the zeroing compact).
    self.prompt_cost = self.completion_cost = 0
    self.prompt_tokens = self.completion_tokens = 0
def start_trace(self, trace_id: Optional[str] = None) -> None:
    """No-op: this handler does not need per-trace bookkeeping."""
    return None
@@ -22,26 +27,33 @@ class TokenCostHandler(BaseCallbackHandler):
return
def _calc_llm_event_cost(self, payload: dict) -> None:
    """Compute the cost and token counts of one LLM event and accumulate them.

    Args:
        payload: Event payload. Expected to carry either
            PROMPT/COMPLETION entries (completion-style call) or
            MESSAGES/RESPONSE entries (chat-style call). Any other
            payload contributes zero to the totals.
    """
    # Zero-initialize ALL four locals: the committed code only initialized
    # the two cost variables, so a payload with neither PROMPT nor MESSAGES
    # raised NameError at the token accumulation below.
    prompt_cost = 0
    completion_cost = 0
    prompt_tokens = 0
    completion_tokens = 0
    # NOTE(review): the original also did a function-local
    # `from llama_index.llms import ChatMessage` here, shadowing the
    # module-level `llama_index.core.llms` import; removed as redundant.
    if EventPayload.PROMPT in payload:
        prompt = str(payload.get(EventPayload.PROMPT))
        completion = str(payload.get(EventPayload.COMPLETION))
        prompt_cost, prompt_tokens = calculate_prompt_cost(prompt, self.model)
        completion_cost, completion_tokens = calculate_completion_cost(
            completion, self.model
        )
    elif EventPayload.MESSAGES in payload:
        messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
        messages_str = "\n".join([str(x) for x in messages])
        prompt_cost, prompt_tokens = calculate_prompt_cost(messages_str, self.model)
        response = str(payload.get(EventPayload.RESPONSE))
        completion_cost, completion_tokens = calculate_completion_cost(
            response, self.model
        )
    print(f"# Prompt cost: {prompt_cost}")
    print(f"# Completion: {completion_cost}")
    print("\n")
    # Accumulate onto the handler so totals survive across multiple runs.
    self.prompt_cost += prompt_cost
    self.completion_cost += completion_cost
    self.prompt_tokens += prompt_tokens
    self.completion_tokens += completion_tokens
def reset_counts(self) -> None:
    """Zero out the accumulated cost and token totals."""
    for counter in (
        "prompt_cost",
        "completion_cost",
        "prompt_tokens",
        "completion_tokens",
    ):
        setattr(self, counter, 0)
def on_event_start(
self,

View File

@@ -1,8 +1,9 @@
"""
Costs dictionary and utility tool for counting tokens
"""
import tiktoken
from typing import Union, List, Dict
from typing import Tuple, Union, List, Dict
from .constants import TOKEN_COSTS
from decimal import Decimal
import logging
@@ -57,10 +58,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
tokens_per_message = 4
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
logging.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
logging.warning(
"gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
)
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logging.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
logging.warning(
"gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
)
return count_message_tokens(messages, model="gpt-4-0613")
else:
raise KeyError(
@@ -118,13 +123,17 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> De
Double-check your spelling, or submit an issue/PR"""
)
cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
cost_per_token_key = (
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
)
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
return Decimal(str(cost_per_token)) * Decimal(num_tokens)
def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal:
def calculate_prompt_cost(
prompt: Union[List[dict], str], model: str
) -> Tuple[Decimal, int]:
"""
Calculate the prompt's cost in USD.
@@ -133,7 +142,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
model (str): The model name.
Returns:
Decimal: The calculated cost in USD.
Tuple[Decimal, int]: The calculated cost in USD and number of tokens.
e.g.:
>>> prompt = [{ "role": "user", "content": "Hello world"},
@@ -164,10 +173,10 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
else count_message_tokens(prompt, model)
)
return calculate_cost_by_tokens(prompt_tokens, model, 'input')
return calculate_cost_by_tokens(prompt_tokens, model, "input"), prompt_tokens
def calculate_completion_cost(completion: str, model: str) -> Decimal:
def calculate_completion_cost(completion: str, model: str) -> Tuple[Decimal, int]:
"""
Calculate the completion's cost in USD.
@@ -176,7 +185,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
model (str): The model name.
Returns:
Decimal: The calculated cost in USD.
Tuple[Decimal, int]: The calculated cost in USD and number of tokens.
e.g.:
>>> completion = "How may I assist you today?"
@@ -191,4 +200,6 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
)
completion_tokens = count_string_tokens(completion, model)
return calculate_cost_by_tokens(completion_tokens, model, 'output')
return calculate_cost_by_tokens(
completion_tokens, model, "output"
), completion_tokens