fixed unbound vals

This commit is contained in:
reibs
2024-01-03 15:16:32 -08:00
parent 5d83c87a56
commit 5d5f6dc7d1
5 changed files with 146 additions and 11 deletions

pyproject.toml
View File

@@ -18,15 +18,19 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"tiktoken",
"pyyaml"
"tiktoken>=0.5.2",
"pyyaml>=6.0.1"
]
[project.optional-dependencies]
dev = [
"pytest",
"flake8",
"coverage",
"pytest>=7.4.4",
"flake8>=3.1.0",
"coverage[toml]>=7.4.0",
]
llama-index = [
"llama-index>=0.9.24"
]
[project.urls]
Homepage = "https://github.com/AgentOps-AI/tokencost"

test_llama_index.py
View File

@@ -0,0 +1,63 @@
# test_llama_index.py
import pytest
from tokencost.callbacks import llama_index
from llama_index.callbacks.schema import CBEventType, EventPayload
from unittest.mock import MagicMock

# Mock the calculate_prompt_cost and calculate_completion_cost functions
# and the USD_PER_TPU constant
STRING = "Hello, world!"


@pytest.fixture
def mock_tokencost(monkeypatch):
    monkeypatch.setattr('tokencost.calculate_prompt_cost', MagicMock(return_value=100))
    monkeypatch.setattr('tokencost.calculate_completion_cost', MagicMock(return_value=200))
    monkeypatch.setattr('tokencost.USD_PER_TPU', 10)


# Mock the ChatMessage class
@pytest.fixture
def mock_chat_message(monkeypatch):
    class MockChatMessage:
        def __init__(self, text):
            self.text = text

        def __str__(self):
            return self.text

    monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
    return MockChatMessage


# Test the _calc_llm_event_cost method for prompt and completion
def test_calc_llm_event_cost_prompt_completion(mock_tokencost, capsys):
    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
    payload = {
        EventPayload.PROMPT: STRING,
        EventPayload.COMPLETION: STRING
    }
    handler._calc_llm_event_cost(payload)
    captured = capsys.readouterr()
    assert "# Prompt cost: 6e-06" in captured.out
    assert "# Completion: 8e-06" in captured.out


# Test the _calc_llm_event_cost method for messages and response
def test_calc_llm_event_cost_messages_response(mock_tokencost, mock_chat_message, capsys):
    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
    payload = {
        EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
        EventPayload.RESPONSE: "test response"
    }
    handler._calc_llm_event_cost(payload)
    captured = capsys.readouterr()
    assert "# Prompt cost: 1.05e-05" in captured.out
    assert "# Completion: 4e-06" in captured.out


# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
# depending on the specific logic and requirements of those methods.
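Note that the asserted costs come from real token arithmetic, not the mocks: the handler module binds calculate_prompt_cost and calculate_completion_cost at import time, so the monkeypatched tokencost attributes never reach it. A sanity check of the first test's numbers, a minimal sketch assuming cl100k_base tokenization and the gpt-3.5-turbo rates the asserted values imply (roughly $0.0015 per 1K prompt tokens, $0.002 per 1K completion tokens):

import tiktoken

# "Hello, world!" encodes to 4 tokens under cl100k_base.
enc = tiktoken.get_encoding("cl100k_base")
n = len(enc.encode("Hello, world!"))

print(f"# Prompt cost: {n * 0.0015 / 1000}")  # -> # Prompt cost: 6e-06
print(f"# Completion: {n * 0.002 / 1000}")    # -> # Completion: 8e-06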

View File

tokencost/callbacks/llama_index.py
View File

@@ -0,0 +1,65 @@
from typing import Any, Dict, List, Optional, cast

from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload

from tokencost import calculate_prompt_cost, calculate_completion_cost, USD_PER_TPU


class TokenCostHandler(BaseCallbackHandler):
    """Callback handler for printing llms inputs/outputs."""

    def __init__(self, model) -> None:
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self.model = model

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        return

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        return

    def _calc_llm_event_cost(self, payload: dict) -> None:
        from llama_index.llms import ChatMessage

        prompt_cost = 0
        completion_cost = 0
        if EventPayload.PROMPT in payload:
            prompt = str(payload.get(EventPayload.PROMPT))
            completion = str(payload.get(EventPayload.COMPLETION))
            prompt_cost = calculate_prompt_cost(prompt, self.model) / USD_PER_TPU
            completion_cost = calculate_completion_cost(completion, self.model) / USD_PER_TPU
        elif EventPayload.MESSAGES in payload:
            messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
            messages_str = "\n".join([str(x) for x in messages])
            prompt_cost = calculate_prompt_cost(messages_str, self.model) / USD_PER_TPU
            response = str(payload.get(EventPayload.RESPONSE))
            completion_cost = calculate_completion_cost(response, self.model) / USD_PER_TPU

        print(f"# Prompt cost: {prompt_cost}")
        print(f"# Completion: {completion_cost}")
        print("\n")

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Count the LLM or Embedding tokens as needed."""
        if event_type == CBEventType.LLM and payload is not None:
            self._calc_llm_event_cost(payload)
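For context, here is how the handler would be wired up; a minimal sketch, assuming the llama-index 0.9.x import paths this commit pins (ServiceContext and CallbackManager belong to that library, not to this repo):

from llama_index import ServiceContext
from llama_index.callbacks import CallbackManager
from tokencost.callbacks.llama_index import TokenCostHandler

# Register the handler so every CBEventType.LLM event end reports its cost.
handler = TokenCostHandler(model="gpt-3.5-turbo")
service_context = ServiceContext.from_defaults(
    callback_manager=CallbackManager([handler])
)
# Queries executed with this service_context then print, per LLM call:
#   # Prompt cost: <usd>
#   # Completion: <usd>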

tox.ini
View File

@@ -1,20 +1,23 @@
 [tox]
-envlist = py3, flake8
+envlist = py3, flake8, py3-llama-index
 isolated_build = true
 
 [testenv]
 deps =
     pytest
     coverage
 commands =
     coverage run --source tokencost -m pytest
     coverage report -m
 
 [testenv:flake8]
 deps = flake8
 commands = flake8 tokencost/
 
+[testenv:py3-llama-index]
+deps =
+    flake8
+    {[testenv]deps}
+    .[llama-index]
+commands =
+    flake8 tokencost/
+    coverage run --source tokencost -m pytest {posargs}
+    coverage report -m
+
 [flake8]
 max-line-length = 120
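The combined lint-and-test environment runs locally with tox -e py3-llama-index; since it was also appended to envlist, a bare tox invocation now exercises it as well.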