Mirror of https://github.com/AgentOps-AI/tokencost.git, synced 2024-06-22 04:30:40 +03:00

Commit: fixed unbound vals
pyproject.toml
@@ -18,15 +18,19 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "tiktoken",
-    "pyyaml"
+    "tiktoken>=0.5.2",
+    "pyyaml>=6.0.1"
 ]

 [project.optional-dependencies]
 dev = [
-    "pytest",
-    "flake8",
-    "coverage",
+    "pytest>=7.4.4",
+    "flake8>=3.1.0",
+    "coverage[toml]>=7.4.0",
 ]
+llama-index = [
+    "llama-index>=0.9.24"
+]

 [project.urls]
 Homepage = "https://github.com/AgentOps-AI/tokencost"
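The llama-index integration is published as an optional extra, so the base package keeps only tiktoken and pyyaml as runtime dependencies. A minimal sketch of how downstream code might treat the extra as optional (the install spelling and the try/except guard are illustrative assumptions, not part of this commit):

# Only available when installed with the extra, e.g.:  pip install "tokencost[llama-index]"
try:
    from tokencost.callbacks.llama_index import TokenCostHandler
except ImportError:
    # llama-index (and therefore the callback module) is not installed
    TokenCostHandler = None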
tests/test_llama_index_callbacks.py (new file, 63 lines)
@@ -0,0 +1,63 @@
# test_llama_index.py
import pytest
from tokencost.callbacks import llama_index
from llama_index.callbacks.schema import CBEventType, EventPayload
from unittest.mock import MagicMock

# Mock the calculate_prompt_cost and calculate_completion_cost functions
# and the USD_PER_TPU constant

STRING = "Hello, world!"


@pytest.fixture
def mock_tokencost(monkeypatch):
    monkeypatch.setattr('tokencost.calculate_prompt_cost', MagicMock(return_value=100))
    monkeypatch.setattr('tokencost.calculate_completion_cost', MagicMock(return_value=200))
    monkeypatch.setattr('tokencost.USD_PER_TPU', 10)

# Mock the ChatMessage class


@pytest.fixture
def mock_chat_message(monkeypatch):
    class MockChatMessage:
        def __init__(self, text):
            self.text = text

        def __str__(self):
            return self.text

    monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
    return MockChatMessage

# Test the _calc_llm_event_cost method for prompt and completion


def test_calc_llm_event_cost_prompt_completion(mock_tokencost, capsys):
    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
    payload = {
        EventPayload.PROMPT: STRING,
        EventPayload.COMPLETION: STRING
    }
    handler._calc_llm_event_cost(payload)
    captured = capsys.readouterr()
    assert "# Prompt cost: 6e-06" in captured.out
    assert "# Completion: 8e-06" in captured.out

# Test the _calc_llm_event_cost method for messages and response


def test_calc_llm_event_cost_messages_response(mock_tokencost, mock_chat_message, capsys):
    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
    payload = {
        EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
        EventPayload.RESPONSE: "test response"
    }
    handler._calc_llm_event_cost(payload)
    captured = capsys.readouterr()
    assert "# Prompt cost: 1.05e-05" in captured.out
    assert "# Completion: 4e-06" in captured.out

# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
# depending on the specific logic and requirements of those methods.
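As the closing comment notes, the dispatch methods are still untested. A minimal sketch of one follow-up test, not part of this commit (it assumes on_event_end keeps its current behaviour of forwarding only CBEventType.LLM payloads):

def test_on_event_end_routes_llm_events():
    handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
    handler._calc_llm_event_cost = MagicMock()
    payload = {EventPayload.PROMPT: STRING, EventPayload.COMPLETION: STRING}

    # LLM events are forwarded to the cost calculation ...
    handler.on_event_end(CBEventType.LLM, payload=payload)
    handler._calc_llm_event_cost.assert_called_once_with(payload)

    # ... while other event types are ignored.
    handler.on_event_end(CBEventType.EMBEDDING, payload=payload)
    handler._calc_llm_event_cost.assert_called_once()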
tokencost/callbacks/__init__.py (new file, empty)
tokencost/callbacks/llama_index.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import Any, Dict, List, Optional, cast
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from tokencost import calculate_prompt_cost, calculate_completion_cost, USD_PER_TPU


class TokenCostHandler(BaseCallbackHandler):
    """Callback handler for printing llms inputs/outputs."""

    def __init__(self, model) -> None:
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self.model = model

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        return

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        return

    def _calc_llm_event_cost(self, payload: dict) -> None:
        from llama_index.llms import ChatMessage

        prompt_cost = 0
        completion_cost = 0
        if EventPayload.PROMPT in payload:
            prompt = str(payload.get(EventPayload.PROMPT))
            completion = str(payload.get(EventPayload.COMPLETION))
            prompt_cost = calculate_prompt_cost(prompt, self.model) / USD_PER_TPU
            completion_cost = calculate_completion_cost(completion, self.model) / USD_PER_TPU

        elif EventPayload.MESSAGES in payload:
            messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
            messages_str = "\n".join([str(x) for x in messages])
            prompt_cost = calculate_prompt_cost(messages_str, self.model) / USD_PER_TPU
            response = str(payload.get(EventPayload.RESPONSE))
            completion_cost = calculate_completion_cost(response, self.model) / USD_PER_TPU

        print(f"# Prompt cost: {prompt_cost}")
        print(f"# Completion: {completion_cost}")
        print("\n")

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Count the LLM or Embedding tokens as needed."""
        if event_type == CBEventType.LLM and payload is not None:
            self._calc_llm_event_cost(payload)
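The zero-initialization of prompt_cost and completion_cost is presumably the "unbound vals" the commit message refers to: if a payload carried neither PROMPT nor MESSAGES, the print calls would otherwise reach the two names before any assignment and raise UnboundLocalError. A rough usage sketch, not part of this commit, assuming the llama-index 0.9.x ServiceContext/CallbackManager API and an illustrative model name and data path:

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks import CallbackManager
from tokencost.callbacks.llama_index import TokenCostHandler

# Register the handler so every LLM event prints its prompt and completion cost.
handler = TokenCostHandler(model="gpt-3.5-turbo")
service_context = ServiceContext.from_defaults(callback_manager=CallbackManager([handler]))

documents = SimpleDirectoryReader("data").load_data()  # "data" is a placeholder path
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
print(index.as_query_engine().query("What does tokencost do?"))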
tox.ini
@@ -1,20 +1,23 @@
 [tox]
-envlist = py3, flake8
+envlist = py3, flake8, py3-llama-index
 isolated_build = true

 [testenv]
 deps =
     pytest
     coverage
 commands =
     coverage run --source tokencost -m pytest
     coverage report -m

 [testenv:flake8]
 deps = flake8
 commands = flake8 tokencost/

+[testenv:py3-llama-index]
+deps =
+    flake8
+    {[testenv]deps}
+    .[llama-index]
+commands =
+    flake8 tokencost/
+    coverage run --source tokencost -m pytest {posargs}
+    coverage report -m

 [flake8]
 max-line-length = 120
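The new environment mirrors the default py3 run but also lints and installs the package with the llama-index extra (.[llama-index]); assuming a standard tox setup, it can be run on its own with `tox -e py3-llama-index` or as part of the full matrix now that it is listed in envlist.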