diff --git a/pyproject.toml b/pyproject.toml
index 5058bb5..314294b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,15 +18,19 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "tiktoken",
-    "pyyaml"
+    "tiktoken>=0.5.2",
+    "pyyaml>=6.0.1"
 ]
 
 [project.optional-dependencies]
 dev = [
-    "pytest",
-    "flake8",
-    "coverage",
+    "pytest>=7.4.4",
+    "flake8>=3.1.0",
+    "coverage[toml]>=7.4.0",
 ]
+llama-index = [
+    "llama-index>=0.9.24"
+]
+
 [project.urls]
 Homepage = "https://github.com/AgentOps-AI/tokencost"
diff --git a/tests/test_llama_index_callbacks.py b/tests/test_llama_index_callbacks.py
new file mode 100644
index 0000000..6541a08
--- /dev/null
+++ b/tests/test_llama_index_callbacks.py
@@ -0,0 +1,63 @@
+# test_llama_index.py
+import pytest
+from tokencost.callbacks import llama_index
+from llama_index.callbacks.schema import CBEventType, EventPayload
+from unittest.mock import MagicMock
+
+# Mock the calculate_prompt_cost and calculate_completion_cost functions
+# and the USD_PER_TPU constant
+
+STRING = "Hello, world!"
+
+
+@pytest.fixture
+def mock_tokencost(monkeypatch):
+    monkeypatch.setattr('tokencost.calculate_prompt_cost', MagicMock(return_value=100))
+    monkeypatch.setattr('tokencost.calculate_completion_cost', MagicMock(return_value=200))
+    monkeypatch.setattr('tokencost.USD_PER_TPU', 10)
+
+# Mock the ChatMessage class
+
+
+@pytest.fixture
+def mock_chat_message(monkeypatch):
+    class MockChatMessage:
+        def __init__(self, text):
+            self.text = text
+
+        def __str__(self):
+            return self.text
+
+    monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
+    return MockChatMessage
+
+# Test the _calc_llm_event_cost method for prompt and completion
+
+
+def test_calc_llm_event_cost_prompt_completion(mock_tokencost, capsys):
+    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
+    payload = {
+        EventPayload.PROMPT: STRING,
+        EventPayload.COMPLETION: STRING
+    }
+    handler._calc_llm_event_cost(payload)
+    captured = capsys.readouterr()
+    assert "# Prompt cost: 6e-06" in captured.out
+    assert "# Completion: 8e-06" in captured.out
+
+# Test the _calc_llm_event_cost method for messages and response
+
+
+def test_calc_llm_event_cost_messages_response(mock_tokencost, mock_chat_message, capsys):
+    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
+    payload = {
+        EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
+        EventPayload.RESPONSE: "test response"
+    }
+    handler._calc_llm_event_cost(payload)
+    captured = capsys.readouterr()
+    assert "# Prompt cost: 1.05e-05" in captured.out
+    assert "# Completion: 4e-06" in captured.out
+
+# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
+# depending on the specific logic and requirements of those methods.
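Reviewer aside (not part of the patch): the asserted cost strings appear to come from the real tokencost price table rather than the mocks, since `mock_tokencost` patches attributes on the `tokencost` package after `tokencost.callbacks.llama_index` has already bound the real functions at import time. A quick sanity check of the 6e-06 / 8e-06 figures, assuming gpt-3.5-turbo was priced at $0.0015 per 1K prompt tokens and $0.002 per 1K completion tokens (the assumed pricing at the time) and that "Hello, world!" encodes to 4 cl100k_base tokens:

```python
import tiktoken

# Illustrative check only; the per-1K prices below are assumptions about tokencost's
# price table at the time, not values taken from this diff.
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
n_tokens = len(enc.encode("Hello, world!"))  # 4 tokens with cl100k_base

print(n_tokens * 0.0015 / 1000)  # 6e-06 -> matches "# Prompt cost: 6e-06"
print(n_tokens * 0.002 / 1000)   # 8e-06 -> matches "# Completion: 8e-06"
```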
diff --git a/tokencost/callbacks/__init__.py b/tokencost/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tokencost/callbacks/llama_index.py b/tokencost/callbacks/llama_index.py
new file mode 100644
index 0000000..0df5b69
--- /dev/null
+++ b/tokencost/callbacks/llama_index.py
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List, Optional, cast
+from llama_index.callbacks.base_handler import BaseCallbackHandler
+from llama_index.callbacks.schema import CBEventType, EventPayload
+from tokencost import calculate_prompt_cost, calculate_completion_cost, USD_PER_TPU
+
+
+class TokenCostHandler(BaseCallbackHandler):
+    """Callback handler for printing llms inputs/outputs."""
+
+    def __init__(self, model) -> None:
+        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
+        self.model = model
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
+        return
+
+    def end_trace(
+        self,
+        trace_id: Optional[str] = None,
+        trace_map: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        return
+
+    def _calc_llm_event_cost(self, payload: dict) -> None:
+        from llama_index.llms import ChatMessage
+
+        prompt_cost = 0
+        completion_cost = 0
+        if EventPayload.PROMPT in payload:
+            prompt = str(payload.get(EventPayload.PROMPT))
+            completion = str(payload.get(EventPayload.COMPLETION))
+            prompt_cost = calculate_prompt_cost(prompt, self.model) / USD_PER_TPU
+            completion_cost = calculate_completion_cost(completion, self.model) / USD_PER_TPU
+
+        elif EventPayload.MESSAGES in payload:
+            messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
+            messages_str = "\n".join([str(x) for x in messages])
+            prompt_cost = calculate_prompt_cost(messages_str, self.model) / USD_PER_TPU
+            response = str(payload.get(EventPayload.RESPONSE))
+            completion_cost = calculate_completion_cost(response, self.model) / USD_PER_TPU
+
+        print(f"# Prompt cost: {prompt_cost}")
+        print(f"# Completion: {completion_cost}")
+        print("\n")
+
+    def on_event_start(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        parent_id: str = "",
+        **kwargs: Any,
+    ) -> str:
+        return event_id
+
+    def on_event_end(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> None:
+        """Count the LLM or Embedding tokens as needed."""
+        if event_type == CBEventType.LLM and payload is not None:
+            self._calc_llm_event_cost(payload)
diff --git a/tox.ini b/tox.ini
index 4a1db7a..48d7815 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,20 +1,23 @@
 [tox]
-envlist = py3, flake8
+envlist = py3, flake8, py3-llama-index
 isolated_build = true
 
 [testenv]
 deps =
     pytest
     coverage
-commands =
-    coverage run --source tokencost -m pytest
-    coverage report -m
 
 [testenv:flake8]
+deps = flake8
+commands = flake8 tokencost/
+
+[testenv:py3-llama-index]
 deps =
-    flake8
+    {[testenv]deps}
+    .[llama-index]
 commands =
-    flake8 tokencost/
+    coverage run --source tokencost -m pytest {posargs}
+    coverage report -m
 
 [flake8]
 max-line-length = 120
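Not part of the diff: a minimal usage sketch of the new handler, assuming the llama-index 0.9.x ServiceContext / CallbackManager API and a local ./data folder of documents (both assumptions for illustration, not taken from this PR):

```python
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks import CallbackManager

from tokencost.callbacks.llama_index import TokenCostHandler

# Every LLM event routed through the callback manager prints its prompt/completion cost.
callback_manager = CallbackManager([TokenCostHandler(model="gpt-3.5-turbo")])
service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
print(index.as_query_engine().query("What do these documents cover?"))
```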