mirror of
https://github.com/AgentOps-AI/tokencost.git
synced 2024-06-22 04:30:40 +03:00
feat: ✨ Track costs and token counts in the class, across multiple runs (e.g. in agents)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
|
||||
from llama_index.core.callbacks.schema import CBEventType, EventPayload
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from tokencost import calculate_prompt_cost, calculate_completion_cost
|
||||
|
||||
|
||||
@@ -10,6 +11,10 @@ class TokenCostHandler(BaseCallbackHandler):
|
||||
def __init__(self, model) -> None:
|
||||
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
|
||||
self.model = model
|
||||
self.prompt_cost = 0
|
||||
self.completion_cost = 0
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
|
||||
def start_trace(self, trace_id: Optional[str] = None) -> None:
|
||||
return
|
||||
@@ -22,26 +27,33 @@ class TokenCostHandler(BaseCallbackHandler):
|
||||
return
|
||||
|
||||
def _calc_llm_event_cost(self, payload: dict) -> None:
|
||||
from llama_index.llms import ChatMessage
|
||||
|
||||
prompt_cost = 0
|
||||
completion_cost = 0
|
||||
if EventPayload.PROMPT in payload:
|
||||
prompt = str(payload.get(EventPayload.PROMPT))
|
||||
completion = str(payload.get(EventPayload.COMPLETION))
|
||||
prompt_cost = calculate_prompt_cost(prompt, self.model)
|
||||
completion_cost = calculate_completion_cost(completion, self.model)
|
||||
prompt_cost, prompt_tokens = calculate_prompt_cost(prompt, self.model)
|
||||
completion_cost, completion_tokens = calculate_completion_cost(
|
||||
completion, self.model
|
||||
)
|
||||
|
||||
elif EventPayload.MESSAGES in payload:
|
||||
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
|
||||
messages_str = "\n".join([str(x) for x in messages])
|
||||
prompt_cost = calculate_prompt_cost(messages_str, self.model)
|
||||
prompt_cost, prompt_tokens = calculate_prompt_cost(messages_str, self.model)
|
||||
response = str(payload.get(EventPayload.RESPONSE))
|
||||
completion_cost = calculate_completion_cost(response, self.model)
|
||||
completion_cost, completion_tokens = calculate_completion_cost(
|
||||
response, self.model
|
||||
)
|
||||
|
||||
print(f"# Prompt cost: {prompt_cost}")
|
||||
print(f"# Completion: {completion_cost}")
|
||||
print("\n")
|
||||
self.prompt_cost += prompt_cost
|
||||
self.completion_cost += completion_cost
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
|
||||
def reset_counts(self) -> None:
|
||||
self.prompt_cost = 0
|
||||
self.completion_cost = 0
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
|
||||
def on_event_start(
|
||||
self,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Costs dictionary and utility tool for counting tokens
|
||||
"""
|
||||
|
||||
import tiktoken
|
||||
from typing import Union, List, Dict
|
||||
from typing import Tuple, Union, List, Dict
|
||||
from .constants import TOKEN_COSTS
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
@@ -57,10 +58,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
|
||||
tokens_per_message = 4
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
logging.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
logging.warning(
|
||||
"gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
|
||||
)
|
||||
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model:
|
||||
logging.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
logging.warning(
|
||||
"gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
|
||||
)
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise KeyError(
|
||||
@@ -118,13 +123,17 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> De
|
||||
Double-check your spelling, or submit an issue/PR"""
|
||||
)
|
||||
|
||||
cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
|
||||
cost_per_token_key = (
|
||||
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
|
||||
)
|
||||
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
|
||||
|
||||
return Decimal(str(cost_per_token)) * Decimal(num_tokens)
|
||||
|
||||
|
||||
def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal:
|
||||
def calculate_prompt_cost(
|
||||
prompt: Union[List[dict], str], model: str
|
||||
) -> Tuple[Decimal, int]:
|
||||
"""
|
||||
Calculate the prompt's cost in USD.
|
||||
|
||||
@@ -133,7 +142,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
|
||||
model (str): The model name.
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated cost in USD.
|
||||
Tuple[Decimal, int]: The calculated cost in USD and number of tokens.
|
||||
|
||||
e.g.:
|
||||
>>> prompt = [{ "role": "user", "content": "Hello world"},
|
||||
@@ -164,10 +173,10 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
|
||||
else count_message_tokens(prompt, model)
|
||||
)
|
||||
|
||||
return calculate_cost_by_tokens(prompt_tokens, model, 'input')
|
||||
return calculate_cost_by_tokens(prompt_tokens, model, "input"), prompt_tokens
|
||||
|
||||
|
||||
def calculate_completion_cost(completion: str, model: str) -> Decimal:
|
||||
def calculate_completion_cost(completion: str, model: str) -> Tuple[Decimal, int]:
|
||||
"""
|
||||
Calculate the completion's cost in USD.
|
||||
|
||||
@@ -176,7 +185,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
|
||||
model (str): The model name.
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated cost in USD.
|
||||
Tuple[Decimal, int]: The calculated cost in USD and number of tokens.
|
||||
|
||||
e.g.:
|
||||
>>> completion = "How may I assist you today?"
|
||||
@@ -191,4 +200,6 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
|
||||
)
|
||||
completion_tokens = count_string_tokens(completion, model)
|
||||
|
||||
return calculate_cost_by_tokens(completion_tokens, model, 'output')
|
||||
return calculate_cost_by_tokens(
|
||||
completion_tokens, model, "output"
|
||||
), completion_tokens
|
||||
|
||||
Reference in New Issue
Block a user