Add caching in Python

Still need it in JS
Kyle Corbitt
2023-08-11 19:02:35 -07:00
parent 8ed47eb4dd
commit d7cff0f52e
5 changed files with 115 additions and 28 deletions

View File

@@ -99,6 +99,14 @@ function TableRow({
     [loggedCall.tags],
   );
 
+  const durationCell = (
+    <Td isNumeric>
+      {loggedCall.cacheHit
+        ? "Cache hit"
+        : ((loggedCall.modelResponse?.durationMs ?? 0) / 1000).toFixed(2) + "s"}
+    </Td>
+  );
+
   return (
     <>
       <Tr
@@ -120,7 +128,7 @@ function TableRow({
          </Tooltip>
        </Td>
        <Td width="100%">{model}</Td>
-       <Td isNumeric>{((loggedCall.modelResponse?.durationMs ?? 0) / 1000).toFixed(2)}s</Td>
+       {durationCell}
        <Td isNumeric>{loggedCall.modelResponse?.inputTokens}</Td>
        <Td isNumeric>{loggedCall.modelResponse?.outputTokens}</Td>
        <Td sx={{ color: isError ? "red.500" : "green.500", fontWeight: "semibold" }} isNumeric>

View File

@@ -70,15 +70,9 @@ export const externalApiRouter = createTRPCRouter({
       const cacheKey = hashRequest(key.projectId, reqPayload as JsonValue);
 
       const existingResponse = await prisma.loggedCallModelResponse.findFirst({
-        where: {
-          cacheKey,
-        },
-        include: {
-          originalLoggedCall: true,
-        },
-        orderBy: {
-          requestedAt: "desc",
-        },
+        where: { cacheKey },
+        include: { originalLoggedCall: true },
+        orderBy: { requestedAt: "desc" },
       });
 
       if (!existingResponse) return { respPayload: null };
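
The lookup keys on cacheKey, a hash of the project ID plus the raw request payload, so identical requests from the same project resolve to the same cached response. hashRequest's implementation isn't part of this diff; a minimal sketch of the idea, using a hypothetical hash_request helper, might look like:

import hashlib
import json

def hash_request(project_id: str, req_payload: dict) -> str:
    # Hypothetical sketch: canonicalize the payload and hash it together with
    # the project ID so identical requests map to the same cache key.
    canonical = json.dumps(req_payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(f"{project_id}:{canonical}".encode()).hexdigest()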

View File

@@ -1,23 +1,24 @@
 import openai as original_openai
+from openai.openai_object import OpenAIObject
 import time
 import inspect
 from openpipe.merge_openai_chunks import merge_streamed_chunks
-from .shared import report_async, report
+from .shared import maybe_check_cache, maybe_check_cache_async, report_async, report
 
 
-class ChatCompletionWrapper:
-    def __getattr__(self, name):
-        return getattr(original_openai.ChatCompletion, name)
-
-    def __setattr__(self, name, value):
-        return setattr(original_openai.ChatCompletion, name, value)
-
+class WrappedChatCompletion(original_openai.ChatCompletion):
     @classmethod
     def create(cls, *args, **kwargs):
         openpipe_options = kwargs.pop("openpipe", {})
+
+        cached_response = maybe_check_cache(
+            openpipe_options=openpipe_options, req_payload=kwargs
+        )
+        if cached_response:
+            return OpenAIObject.construct_from(cached_response, api_key=None)
+
         requested_at = int(time.time() * 1000)
 
         try:
@@ -86,6 +87,12 @@ class ChatCompletionWrapper:
     async def acreate(cls, *args, **kwargs):
         openpipe_options = kwargs.pop("openpipe", {})
 
+        cached_response = await maybe_check_cache_async(
+            openpipe_options=openpipe_options, req_payload=kwargs
+        )
+        if cached_response:
+            return OpenAIObject.construct_from(cached_response, api_key=None)
+
         requested_at = int(time.time() * 1000)
 
         try:
@@ -152,13 +159,10 @@ class ChatCompletionWrapper:
 
 
 class OpenAIWrapper:
-    ChatCompletion = ChatCompletionWrapper()
+    ChatCompletion = WrappedChatCompletion()
 
     def __getattr__(self, name):
         return getattr(original_openai, name)
 
     def __setattr__(self, name, value):
         return setattr(original_openai, name, value)
-
-    def __dir__(self):
-        return dir(original_openai) + ["openpipe_base_url", "openpipe_api_key"]
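
With WrappedChatCompletion in place, callers opt into caching per request through the openpipe option; on a hit, create/acreate return the cached payload reconstructed as an OpenAIObject and never call OpenAI. A minimal usage sketch, assuming the wrapped client is exposed as openpipe.openai (as the tests below suggest) and OpenPipe has already been configured:

from openpipe import openai  # assumed export of OpenAIWrapper

completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "system", "content": "count to 10"}],
    openpipe={"cache": True},  # opt into the request cache
)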

View File

@@ -1,9 +1,13 @@
-from openpipe.api_client.api.default import external_api_report
+from openpipe.api_client.api.default import (
+    external_api_report,
+    external_api_check_cache,
+)
 from openpipe.api_client.client import AuthenticatedClient
 from openpipe.api_client.models.external_api_report_json_body_tags import (
     ExternalApiReportJsonBodyTags,
 )
 import toml
+import time
 
 version = toml.load("pyproject.toml")["tool"]["poetry"]["version"]
@@ -20,6 +24,71 @@ def _get_tags(openpipe_options):
     return ExternalApiReportJsonBodyTags.from_dict(tags)
 
 
+def _should_check_cache(openpipe_options):
+    if configured_client.token == "":
+        return False
+    return openpipe_options.get("cache", False)
+
+
+def _process_cache_payload(
+    payload: external_api_check_cache.ExternalApiCheckCacheResponse200,
+):
+    if not payload or not payload.resp_payload:
+        return None
+    payload.resp_payload["openpipe"] = {"cache_status": "HIT"}
+    return payload.resp_payload
+
+
+def maybe_check_cache(
+    openpipe_options={},
+    req_payload={},
+):
+    if not _should_check_cache(openpipe_options):
+        return None
+
+    try:
+        payload = external_api_check_cache.sync(
+            client=configured_client,
+            json_body=external_api_check_cache.ExternalApiCheckCacheJsonBody(
+                req_payload=req_payload,
+                requested_at=int(time.time() * 1000),
+                tags=_get_tags(openpipe_options),
+            ),
+        )
+        return _process_cache_payload(payload)
+    except Exception as e:
+        # We don't want to break client apps if our API is down for some reason
+        print(f"Error reporting to OpenPipe: {e}")
+        print(e)
+        return None
+
+
+async def maybe_check_cache_async(
+    openpipe_options={},
+    req_payload={},
+):
+    if not _should_check_cache(openpipe_options):
+        return None
+
+    try:
+        payload = await external_api_check_cache.asyncio(
+            client=configured_client,
+            json_body=external_api_check_cache.ExternalApiCheckCacheJsonBody(
+                req_payload=req_payload,
+                requested_at=int(time.time() * 1000),
+                tags=_get_tags(openpipe_options),
+            ),
+        )
+        return _process_cache_payload(payload)
+    except Exception as e:
+        # We don't want to break client apps if our API is down for some reason
+        print(f"Error reporting to OpenPipe: {e}")
+        print(e)
+        return None
+
+
 def report(
     openpipe_options={},
     **kwargs,
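
Two details from the helpers above are worth noting: caching is skipped entirely unless an API token is configured and the caller passes openpipe={"cache": True}, and _process_cache_payload tags every cached payload with an openpipe key so hits are distinguishable from fresh completions. A small sketch of checking for a hit, assuming completion was obtained as in the usage sketch after the wrapper diff (OpenAIObject behaves like a dict):

# Cached responses carry {"openpipe": {"cache_status": "HIT"}}; fresh ones don't.
if completion.get("openpipe", {}).get("cache_status") == "HIT":
    print("Served from the OpenPipe cache")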

View File

@@ -12,7 +12,6 @@ configure_openpipe(
 )
 
 
-@pytest.mark.skip
 def test_sync():
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
@@ -22,7 +21,6 @@ def test_sync():
     print(completion.choices[0].message.content)
 
 
-@pytest.mark.skip
 def test_streaming():
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
@@ -34,7 +32,6 @@ def test_streaming():
         print(chunk)
 
 
-@pytest.mark.skip
 async def test_async():
     acompletion = await openai.ChatCompletion.acreate(
         model="gpt-3.5-turbo",
@@ -44,7 +41,6 @@ async def test_async():
     print(acompletion.choices[0].message.content)
 
 
-@pytest.mark.skip
 async def test_async_streaming():
     acompletion = await openai.ChatCompletion.acreate(
         model="gpt-3.5-turbo",
@@ -67,10 +63,26 @@ def test_sync_with_tags():
     print(completion.choices[0].message.content)
 
 
-@pytest.mark.focus
 def test_bad_call():
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo-blaster",
         messages=[{"role": "system", "content": "count to 10"}],
         stream=True,
     )
+
+
+@pytest.mark.focus
+async def test_caching():
+    completion = openai.ChatCompletion.create(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "system", "content": "count to 10"}],
+        openpipe={"cache": True},
+    )
+
+    completion2 = await openai.ChatCompletion.acreate(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "system", "content": "count to 10"}],
+        openpipe={"cache": True},
+    )
+
+    print(completion2)