test: Fix issues with unit tests

This commit is contained in:
André Ferreira
2024-05-28 12:09:37 +02:00
parent 691275e39d
commit 480de19113
3 changed files with 51 additions and 44 deletions

View File

@@ -131,27 +131,27 @@ def test_count_string_invalid_model():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(MESSAGES, "gpt-3.5-turbo", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal('0.0000255')),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal('0.000015')),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal('0.0000225')),
(MESSAGES, "gpt-4", Decimal('0.00045')),
(MESSAGES, "gpt-4-0314", Decimal('0.00045')),
(MESSAGES, "gpt-4-32k", Decimal('0.00090')),
(MESSAGES, "gpt-4-32k-0314", Decimal('0.00090')),
(MESSAGES, "gpt-4-0613", Decimal('0.00045')),
(MESSAGES, "gpt-4-1106-preview", Decimal('0.00015')),
(MESSAGES, "gpt-4-vision-preview", Decimal('0.00015')),
(STRING, "text-embedding-ada-002", Decimal('0.0000004')),
(MESSAGES, "gpt-3.5-turbo", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal("0.0000255")),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal("0.000015")),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal("0.0000225")),
(MESSAGES, "gpt-4", Decimal("0.00045")),
(MESSAGES, "gpt-4-0314", Decimal("0.00045")),
(MESSAGES, "gpt-4-32k", Decimal("0.00090")),
(MESSAGES, "gpt-4-32k-0314", Decimal("0.00090")),
(MESSAGES, "gpt-4-0613", Decimal("0.00045")),
(MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
(MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
(STRING, "text-embedding-ada-002", Decimal("0.0000004")),
],
)
def test_calculate_prompt_cost(prompt, model, expected_output):
"""Test that the cost calculation is correct."""
cost = calculate_prompt_cost(prompt, model)
cost, _ = calculate_prompt_cost(prompt, model)
assert cost == expected_output
@@ -165,27 +165,27 @@ def test_invalid_prompt_format():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(STRING, "gpt-3.5-turbo", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0301", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0613", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-16k", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-1106", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-instruct", Decimal('0.000008')),
(STRING, "gpt-4", Decimal('0.00024')),
(STRING, "gpt-4-0314", Decimal('0.00024')),
(STRING, "gpt-4-32k", Decimal('0.00048')),
(STRING, "gpt-4-32k-0314", Decimal('0.00048')),
(STRING, "gpt-4-0613", Decimal('0.00024')),
(STRING, "gpt-4-1106-preview", Decimal('0.00012')),
(STRING, "gpt-4-vision-preview", Decimal('0.00012')),
(STRING, "gpt-3.5-turbo", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0301", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0613", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-16k", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-1106", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-instruct", Decimal("0.000008")),
(STRING, "gpt-4", Decimal("0.00024")),
(STRING, "gpt-4-0314", Decimal("0.00024")),
(STRING, "gpt-4-32k", Decimal("0.00048")),
(STRING, "gpt-4-32k-0314", Decimal("0.00048")),
(STRING, "gpt-4-0613", Decimal("0.00024")),
(STRING, "gpt-4-1106-preview", Decimal("0.00012")),
(STRING, "gpt-4-vision-preview", Decimal("0.00012")),
(STRING, "text-embedding-ada-002", 0),
],
)
def test_calculate_completion_cost(prompt, model, expected_output):
"""Test that the completion cost calculation is correct."""
cost = calculate_completion_cost(prompt, model)
cost, _ = calculate_completion_cost(prompt, model)
assert cost == expected_output
@@ -213,9 +213,9 @@ def test_calculate_invalid_input_types():
@pytest.mark.parametrize(
"num_tokens,model,token_type,expected_output",
[
(10, "gpt-3.5-turbo", 'input', Decimal('0.0000150')), # Example values
(5, "gpt-4", 'output', Decimal('0.00030')), # Example values
(10, "ai21.j2-mid-v1", 'input', Decimal('0.0001250')), # Example values
(10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values
(5, "gpt-4", "output", Decimal("0.00030")), # Example values
(10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values
],
)
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):

View File

@@ -1,7 +1,7 @@
# test_llama_index.py
import pytest
from tokencost.callbacks import llama_index
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from unittest.mock import MagicMock
# Mock the calculate_prompt_cost and calculate_completion_cost functions
@@ -20,36 +20,39 @@ def mock_chat_message(monkeypatch):
def __str__(self):
return self.text
monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
monkeypatch.setattr("llama_index.core.llms.ChatMessage", MockChatMessage)
return MockChatMessage
# Test the _calc_llm_event_cost method for prompt and completion
def test_calc_llm_event_cost_prompt_completion(capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
payload = {
EventPayload.PROMPT: STRING,
EventPayload.COMPLETION: STRING
}
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {EventPayload.PROMPT: STRING, EventPayload.COMPLETION: STRING}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000060" in captured.out
assert "# Completion: 0.000008" in captured.out
# Test the _calc_llm_event_cost method for messages and response
def test_calc_llm_event_cost_messages_response(mock_chat_message, capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {
EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
EventPayload.RESPONSE: "test response"
EventPayload.MESSAGES: [
mock_chat_message("message 1"),
mock_chat_message("message 2"),
],
EventPayload.RESPONSE: "test response",
}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000105" in captured.out
assert "# Completion: 0.000004" in captured.out
# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
# depending on the specific logic and requirements of those methods.

View File

@@ -49,6 +49,10 @@ class TokenCostHandler(BaseCallbackHandler):
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
print(f"# Prompt cost: {prompt_cost}")
print(f"# Completion: {completion_cost}")
print("\n")
def reset_counts(self) -> None:
self.prompt_cost = 0
self.completion_cost = 0