Add calculate_cost_by_tokens for the case when the token count is already available

This commit is contained in:
Nikolay Petrov
2024-02-01 17:42:16 +07:00
parent 7959de29f8
commit 7d0595a692
3 changed files with 49 additions and 15 deletions

View File

@@ -301,13 +301,13 @@ pip install -e .
0. Install `pytest` if you don't have it already
```python
```shell
pip install pytest
```
1. Run the `tests/` folder while in the parent directory
```python
```shell
pytest tests
```

View File

@@ -6,6 +6,7 @@ from decimal import Decimal
from tokencost.costs import (
count_message_tokens,
count_string_tokens,
calculate_cost_by_tokens,
calculate_prompt_cost,
calculate_completion_cost,
)
@@ -204,6 +205,20 @@ def test_calculate_invalid_input_types():
with pytest.raises(KeyError):
calculate_completion_cost(STRING, model="invalid_model")
with pytest.raises(KeyError):
with pytest.raises(TypeError):
# Message objects not allowed, must be list of message objects.
calculate_prompt_cost(MESSAGES[0], model="invalid_model")
@pytest.mark.parametrize(
    "num_tokens,model,token_type,expected_output",
    [
        # Example values: (token count, model name, token type, expected cost).
        (10, "gpt-3.5-turbo", "input", Decimal("0.0000150")),
        (5, "gpt-4", "output", Decimal("0.00030")),
        (10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")),
    ],
)
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):
    """Check that a cost computed from a token count equals the expected value."""
    assert calculate_cost_by_tokens(num_tokens, model, token_type) == expected_output

View File

@@ -91,6 +91,31 @@ def count_string_tokens(prompt: str, model: str) -> int:
return len(encoding.encode(prompt))
def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
    """
    Calculate the cost of an already-known number of tokens for a model.

    Args:
        num_tokens (int): The number of tokens.
        model (str): The model name. It is lower-cased before being looked
            up in TOKEN_COSTS.
        token_type (str): Type of token, exactly 'input' or 'output'.

    Returns:
        Decimal: The calculated cost in USD.

    Raises:
        KeyError: If the model has no entry in TOKEN_COSTS.
        ValueError: If token_type is neither 'input' nor 'output'.
    """
    model = model.lower()
    if model not in TOKEN_COSTS:
        raise KeyError(
            f"Model {model} is not implemented.\n"
            "Double-check your spelling, or submit an issue/PR"
        )

    # Reject unknown token types explicitly; otherwise a typo such as
    # 'Input' or 'prompt' would silently be billed at the output rate.
    if token_type not in ("input", "output"):
        raise ValueError(
            f"token_type must be 'input' or 'output', got {token_type!r}"
        )

    cost_per_token = TOKEN_COSTS[model][f"{token_type}_cost_per_token"]
    # Going through str() keeps the short decimal form of the stored price
    # if it is a float, instead of its exact binary expansion.
    return Decimal(str(cost_per_token)) * Decimal(num_tokens)
def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal:
"""
Calculate the prompt's cost in USD.
@@ -113,15 +138,10 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
Decimal('0.0000030')
"""
model = model.lower()
if model not in TOKEN_COSTS:
raise KeyError(
f"""Model {model} is not implemented.
Double-check your spelling, or submit an issue/PR"""
)
if not isinstance(prompt, (list, str)) or not isinstance(prompt, (list, str)):
if not isinstance(prompt, (list, str)):
raise TypeError(
f"""Prompt and completion each must be either a string or list of message objects.
They are {type(prompt)} and {type(prompt)}, respectively.
f"""Prompt must be either a string or list of message objects.
it is {type(prompt)} instead.
"""
)
prompt_tokens = (
@@ -129,8 +149,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
if isinstance(prompt, str)
else count_message_tokens(prompt, model)
)
prompt_cost = TOKEN_COSTS[model]["input_cost_per_token"]
return Decimal(str(prompt_cost)) * Decimal(prompt_tokens)
return calculate_cost_by_tokens(prompt_tokens, model, 'input')
def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -155,6 +175,5 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
Double-check your spelling, or submit an issue/PR"""
)
completion_tokens = count_string_tokens(completion, model)
completion_cost = TOKEN_COSTS[model]["output_cost_per_token"]
return Decimal(str(completion_cost)) * Decimal(completion_tokens)
return calculate_cost_by_tokens(completion_tokens, model, 'output')