python compat fixes

This commit is contained in:
Kyle Corbitt
2023-08-23 17:14:19 -07:00
parent f59150ff5b
commit 50a79b6e3a
4 changed files with 25 additions and 24 deletions

View File

@@ -4,7 +4,7 @@ import time
import inspect import inspect
from openpipe.merge_openai_chunks import merge_openai_chunks from openpipe.merge_openai_chunks import merge_openai_chunks
from openpipe.openpipe_meta import OpenPipeMeta from openpipe.openpipe_meta import openpipe_meta
from .shared import ( from .shared import (
_should_check_cache, _should_check_cache,
@@ -41,9 +41,11 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
) )
cache_status = ( cache_status = (
"MISS" if _should_check_cache(openpipe_options) else "SKIP" "MISS"
if _should_check_cache(openpipe_options, kwargs)
else "SKIP"
) )
chunk.openpipe = OpenPipeMeta(cache_status=cache_status) chunk.openpipe = openpipe_meta(cache_status=cache_status)
yield chunk yield chunk
@@ -72,9 +74,9 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
) )
cache_status = ( cache_status = (
"MISS" if _should_check_cache(openpipe_options) else "SKIP" "MISS" if _should_check_cache(openpipe_options, kwargs) else "SKIP"
) )
chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status) chat_completion["openpipe"] = openpipe_meta(cache_status=cache_status)
return chat_completion return chat_completion
except Exception as e: except Exception as e:
received_at = int(time.time() * 1000) received_at = int(time.time() * 1000)
@@ -126,9 +128,11 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
assembled_completion, chunk assembled_completion, chunk
) )
cache_status = ( cache_status = (
"MISS" if _should_check_cache(openpipe_options) else "SKIP" "MISS"
if _should_check_cache(openpipe_options, kwargs)
else "SKIP"
) )
chunk.openpipe = OpenPipeMeta(cache_status=cache_status) chunk.openpipe = openpipe_meta(cache_status=cache_status)
yield chunk yield chunk
@@ -157,9 +161,9 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
) )
cache_status = ( cache_status = (
"MISS" if _should_check_cache(openpipe_options) else "SKIP" "MISS" if _should_check_cache(openpipe_options, kwargs) else "SKIP"
) )
chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status) chat_completion["openpipe"] = openpipe_meta(cache_status=cache_status)
return chat_completion return chat_completion
except Exception as e: except Exception as e:

View File

@@ -1,7 +1,2 @@
from attr import dataclass def openpipe_meta(cache_status: str):
return {"cache_status": cache_status}
@dataclass
class OpenPipeMeta:
# Cache status. One of 'HIT', 'MISS', 'SKIP'
cache_status: str

View File

@@ -27,12 +27,14 @@ def last_logged_call():
return local_testing_only_get_latest_logged_call.sync(client=configured_client) return local_testing_only_get_latest_logged_call.sync(client=configured_client)
@pytest.mark.focus
def test_sync(): def test_sync():
completion = openai.ChatCompletion.create( completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "system", "content": "count to 3"}], messages=[{"role": "system", "content": "count to 3"}],
) )
print("completion is", completion)
last_logged = last_logged_call() last_logged = last_logged_call()
assert ( assert (
last_logged.model_response.resp_payload["choices"][0]["message"]["content"] last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
@@ -42,7 +44,7 @@ def test_sync():
last_logged.model_response.req_payload["messages"][0]["content"] == "count to 3" last_logged.model_response.req_payload["messages"][0]["content"] == "count to 3"
) )
assert completion.openpipe.cache_status == "SKIP" assert completion.openpipe["cache_status"] == "SKIP"
def test_streaming(): def test_streaming():
@@ -75,7 +77,7 @@ async def test_async():
== "count down from 5" == "count down from 5"
) )
assert completion.openpipe.cache_status == "SKIP" assert completion.openpipe["cache_status"] == "SKIP"
async def test_async_streaming(): async def test_async_streaming():
@@ -87,7 +89,7 @@ async def test_async_streaming():
merged = None merged = None
async for chunk in completion: async for chunk in completion:
assert chunk.openpipe.cache_status == "SKIP" assert chunk.openpipe["cache_status"] == "SKIP"
merged = merge_openai_chunks(merged, chunk) merged = merge_openai_chunks(merged, chunk)
last_logged = last_logged_call() last_logged = last_logged_call()
@@ -100,7 +102,7 @@ async def test_async_streaming():
last_logged.model_response.req_payload["messages"][0]["content"] last_logged.model_response.req_payload["messages"][0]["content"]
== "count down from 5" == "count down from 5"
) )
assert merged["openpipe"].cache_status == "SKIP" assert merged["openpipe"]["cache_status"] == "SKIP"
def test_sync_with_tags(): def test_sync_with_tags():
@@ -146,7 +148,7 @@ async def test_caching():
messages=messages, messages=messages,
openpipe={"cache": True}, openpipe={"cache": True},
) )
assert completion.openpipe.cache_status == "MISS" assert completion.openpipe["cache_status"] == "MISS"
first_logged = last_logged_call() first_logged = last_logged_call()
assert ( assert (
@@ -159,4 +161,4 @@ async def test_caching():
messages=messages, messages=messages,
openpipe={"cache": True}, openpipe={"cache": True},
) )
assert completion2.openpipe.cache_status == "HIT" assert completion2.openpipe["cache_status"] == "HIT"

View File

@@ -1,8 +1,8 @@
[tool.poetry] [tool.poetry]
name = "openpipe" name = "openpipe"
version = "0.1.0" version = "3.0.0"
description = "" description = ""
authors = ["Kyle Corbitt <kyle@corbt.com>"] authors = ["Kyle Corbitt <kyle@openpipe.ai>"]
license = "Apache-2.0" license = "Apache-2.0"
[tool.poetry.dependencies] [tool.poetry.dependencies]