mirror of
https://github.com/AgentOps-AI/tokencost.git
synced 2024-06-22 04:30:40 +03:00
added calculate_cost_by_tokens - the case when token count is already available
This commit is contained in:
@@ -301,13 +301,13 @@ pip install -e .
|
||||
|
||||
0. Install `pytest` if you don't have it already
|
||||
|
||||
```python
|
||||
```shell
|
||||
pip install pytest
|
||||
```
|
||||
|
||||
1. Run the `tests/` folder while in the parent directory
|
||||
|
||||
```python
|
||||
```shell
|
||||
pytest tests
|
||||
```
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from decimal import Decimal
|
||||
from tokencost.costs import (
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
calculate_cost_by_tokens,
|
||||
calculate_prompt_cost,
|
||||
calculate_completion_cost,
|
||||
)
|
||||
@@ -204,6 +205,20 @@ def test_calculate_invalid_input_types():
|
||||
with pytest.raises(KeyError):
|
||||
calculate_completion_cost(STRING, model="invalid_model")
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
with pytest.raises(TypeError):
|
||||
# Message objects not allowed, must be list of message objects.
|
||||
calculate_prompt_cost(MESSAGES[0], model="invalid_model")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,model,token_type,expected_output",
|
||||
[
|
||||
(10, "gpt-3.5-turbo", 'input', Decimal('0.0000150')), # Example values
|
||||
(5, "gpt-4", 'output', Decimal('0.00030')), # Example values
|
||||
(10, "ai21.j2-mid-v1", 'input', Decimal('0.0001250')), # Example values
|
||||
],
|
||||
)
|
||||
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):
|
||||
"""Test that the token cost calculation is correct."""
|
||||
cost = calculate_cost_by_tokens(num_tokens, model, token_type)
|
||||
assert cost == expected_output
|
||||
|
||||
@@ -91,6 +91,31 @@ def count_string_tokens(prompt: str, model: str) -> int:
|
||||
return len(encoding.encode(prompt))
|
||||
|
||||
|
||||
def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
|
||||
"""
|
||||
Calculate the cost based on the number of tokens and the model.
|
||||
|
||||
Args:
|
||||
num_tokens (int): The number of tokens.
|
||||
model (str): The model name.
|
||||
token_type (str): Type of token ('input' or 'output').
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated cost in USD.
|
||||
"""
|
||||
model = model.lower()
|
||||
if model not in TOKEN_COSTS:
|
||||
raise KeyError(
|
||||
f"""Model {model} is not implemented.
|
||||
Double-check your spelling, or submit an issue/PR"""
|
||||
)
|
||||
|
||||
cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
|
||||
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]
|
||||
|
||||
return Decimal(str(cost_per_token)) * Decimal(num_tokens)
|
||||
|
||||
|
||||
def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal:
|
||||
"""
|
||||
Calculate the prompt's cost in USD.
|
||||
@@ -113,15 +138,10 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
|
||||
Decimal('0.0000030')
|
||||
"""
|
||||
model = model.lower()
|
||||
if model not in TOKEN_COSTS:
|
||||
raise KeyError(
|
||||
f"""Model {model} is not implemented.
|
||||
Double-check your spelling, or submit an issue/PR"""
|
||||
)
|
||||
if not isinstance(prompt, (list, str)) or not isinstance(prompt, (list, str)):
|
||||
if not isinstance(prompt, (list, str)):
|
||||
raise TypeError(
|
||||
f"""Prompt and completion each must be either a string or list of message objects.
|
||||
They are {type(prompt)} and {type(prompt)}, respectively.
|
||||
f"""Prompt must be either a string or list of message objects.
|
||||
it is {type(prompt)} instead.
|
||||
"""
|
||||
)
|
||||
prompt_tokens = (
|
||||
@@ -129,8 +149,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
|
||||
if isinstance(prompt, str)
|
||||
else count_message_tokens(prompt, model)
|
||||
)
|
||||
prompt_cost = TOKEN_COSTS[model]["input_cost_per_token"]
|
||||
return Decimal(str(prompt_cost)) * Decimal(prompt_tokens)
|
||||
|
||||
return calculate_cost_by_tokens(prompt_tokens, model, 'input')
|
||||
|
||||
|
||||
def calculate_completion_cost(completion: str, model: str) -> Decimal:
|
||||
@@ -155,6 +175,5 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
|
||||
Double-check your spelling, or submit an issue/PR"""
|
||||
)
|
||||
completion_tokens = count_string_tokens(completion, model)
|
||||
completion_cost = TOKEN_COSTS[model]["output_cost_per_token"]
|
||||
|
||||
return Decimal(str(completion_cost)) * Decimal(completion_tokens)
|
||||
return calculate_cost_by_tokens(completion_tokens, model, 'output')
|
||||
|
||||
Reference in New Issue
Block a user