Python package improvements

Added an endpoint for fetching the most recently logged call (local testing only), and used it to test and improve the Python package.
This commit is contained in:
Kyle Corbitt
2023-08-14 19:07:03 -07:00
parent c4cef35717
commit 754e273049
17 changed files with 786 additions and 57 deletions
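For context, the test loop this commit enables looks roughly like the sketch below: make a call through the wrapped OpenAI client, then ask the server for the call it actually stored and assert on it. The names are taken from the generated client and the tests in this diff, and the local base URL is the one the tests use; treat this as a sketch, not canonical usage.

import os

from openpipe import openai, configure_openpipe, configured_client
from openpipe.api_client.api.default import local_testing_only_get_latest_logged_call

configure_openpipe(
    base_url="http://localhost:3000/api/v1", api_key=os.getenv("OPENPIPE_API_KEY")
)

completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "system", "content": "count to 3"}],
)

# Fetch what the server just logged and check it matches what we sent.
last_logged = local_testing_only_get_latest_logged_call.sync(client=configured_client)
assert last_logged.model_response.req_payload["messages"][0]["content"] == "count to 3"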

View File

@@ -0,0 +1,133 @@
from http import HTTPStatus
from typing import Any, Dict, Optional, Union
import httpx
from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.local_testing_only_get_latest_logged_call_response_200 import (
LocalTestingOnlyGetLatestLoggedCallResponse200,
)
from ...types import Response
def _get_kwargs() -> Dict[str, Any]:
return {
"method": "get",
"url": "/local-testing-only-get-latest-logged-call",
}
def _parse_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
if response.status_code == HTTPStatus.OK:
_response_200 = response.json()
response_200: Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
if _response_200 is None:
response_200 = None
else:
response_200 = LocalTestingOnlyGetLatestLoggedCallResponse200.from_dict(_response_200)
return response_200
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None
def _build_response(
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)
def sync_detailed(
*,
client: AuthenticatedClient,
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
"""Get the latest logged call (only for local testing)
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]
"""
kwargs = _get_kwargs()
response = client.get_httpx_client().request(
**kwargs,
)
return _build_response(client=client, response=response)
def sync(
*,
client: AuthenticatedClient,
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
"""Get the latest logged call (only for local testing)
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
"""
return sync_detailed(
client=client,
).parsed
async def asyncio_detailed(
*,
client: AuthenticatedClient,
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
"""Get the latest logged call (only for local testing)
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]
"""
kwargs = _get_kwargs()
response = await client.get_async_httpx_client().request(**kwargs)
return _build_response(client=client, response=response)
async def asyncio(
*,
client: AuthenticatedClient,
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
"""Get the latest logged call (only for local testing)
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
"""
return (
await asyncio_detailed(
client=client,
)
).parsed
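A note on the module shape: this follows the standard openapi-python-client layout, where sync/asyncio return only the parsed body and the _detailed variants also expose status code, headers, and raw content. A small sketch of the difference, assuming a configured AuthenticatedClient:

from http import HTTPStatus

from openpipe.api_client.api.default import local_testing_only_get_latest_logged_call
from openpipe.api_client.client import AuthenticatedClient

client = AuthenticatedClient(base_url="http://localhost:3000/api/v1", token="MY_KEY")

# Parsed-body-only variant: returns the model (or None).
parsed = local_testing_only_get_latest_logged_call.sync(client=client)

# Detailed variant: transport metadata plus the same parsed model.
detailed = local_testing_only_get_latest_logged_call.sync_detailed(client=client)
assert detailed.status_code == HTTPStatus.OK
print(detailed.parsed)  # the model `sync` would have returned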

View File

@@ -3,6 +3,13 @@
 from .check_cache_json_body import CheckCacheJsonBody
 from .check_cache_json_body_tags import CheckCacheJsonBodyTags
 from .check_cache_response_200 import CheckCacheResponse200
+from .local_testing_only_get_latest_logged_call_response_200 import LocalTestingOnlyGetLatestLoggedCallResponse200
+from .local_testing_only_get_latest_logged_call_response_200_model_response import (
+    LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
+)
+from .local_testing_only_get_latest_logged_call_response_200_tags import (
+    LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
+)
 from .report_json_body import ReportJsonBody
 from .report_json_body_tags import ReportJsonBodyTags
@@ -10,6 +17,9 @@ __all__ = (
     "CheckCacheJsonBody",
     "CheckCacheJsonBodyTags",
     "CheckCacheResponse200",
+    "LocalTestingOnlyGetLatestLoggedCallResponse200",
+    "LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse",
+    "LocalTestingOnlyGetLatestLoggedCallResponse200Tags",
     "ReportJsonBody",
     "ReportJsonBodyTags",
 )

View File

@@ -0,0 +1,84 @@
import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar
from attrs import define
from dateutil.parser import isoparse
if TYPE_CHECKING:
from ..models.local_testing_only_get_latest_logged_call_response_200_model_response import (
LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
)
from ..models.local_testing_only_get_latest_logged_call_response_200_tags import (
LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
)
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200")
@define
class LocalTestingOnlyGetLatestLoggedCallResponse200:
"""
Attributes:
created_at (datetime.datetime):
cache_hit (bool):
tags (LocalTestingOnlyGetLatestLoggedCallResponse200Tags):
model_response (Optional[LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse]):
"""
created_at: datetime.datetime
cache_hit: bool
tags: "LocalTestingOnlyGetLatestLoggedCallResponse200Tags"
model_response: Optional["LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse"]
def to_dict(self) -> Dict[str, Any]:
created_at = self.created_at.isoformat()
cache_hit = self.cache_hit
tags = self.tags.to_dict()
model_response = self.model_response.to_dict() if self.model_response else None
field_dict: Dict[str, Any] = {}
field_dict.update(
{
"createdAt": created_at,
"cacheHit": cache_hit,
"tags": tags,
"modelResponse": model_response,
}
)
return field_dict
@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
from ..models.local_testing_only_get_latest_logged_call_response_200_model_response import (
LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
)
from ..models.local_testing_only_get_latest_logged_call_response_200_tags import (
LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
)
d = src_dict.copy()
created_at = isoparse(d.pop("createdAt"))
cache_hit = d.pop("cacheHit")
tags = LocalTestingOnlyGetLatestLoggedCallResponse200Tags.from_dict(d.pop("tags"))
_model_response = d.pop("modelResponse")
model_response: Optional[LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse]
if _model_response is None:
model_response = None
else:
model_response = LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse.from_dict(_model_response)
local_testing_only_get_latest_logged_call_response_200 = cls(
created_at=created_at,
cache_hit=cache_hit,
tags=tags,
model_response=model_response,
)
return local_testing_only_get_latest_logged_call_response_200
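The model maps the API's camelCase wire keys onto snake_case attributes. A small round-trip illustration (payload values invented for the example):

from openpipe.api_client.models import LocalTestingOnlyGetLatestLoggedCallResponse200

raw = {
    "createdAt": "2023-08-14T19:07:03-07:00",
    "cacheHit": False,
    "tags": {"promptId": "testprompt", "$sdk": "python"},
    "modelResponse": None,  # nullable on the wire, Optional on the model
}

logged = LocalTestingOnlyGetLatestLoggedCallResponse200.from_dict(raw)
assert logged.cache_hit is False
assert logged.to_dict()["cacheHit"] is False  # serializes back to camelCase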

View File

@@ -0,0 +1,70 @@
from typing import Any, Dict, Optional, Type, TypeVar, Union
from attrs import define
from ..types import UNSET, Unset
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse")
@define
class LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse:
"""
Attributes:
id (str):
status_code (Optional[float]):
error_message (Optional[str]):
req_payload (Union[Unset, Any]):
resp_payload (Union[Unset, Any]):
"""
id: str
status_code: Optional[float]
error_message: Optional[str]
req_payload: Union[Unset, Any] = UNSET
resp_payload: Union[Unset, Any] = UNSET
def to_dict(self) -> Dict[str, Any]:
id = self.id
status_code = self.status_code
error_message = self.error_message
req_payload = self.req_payload
resp_payload = self.resp_payload
field_dict: Dict[str, Any] = {}
field_dict.update(
{
"id": id,
"statusCode": status_code,
"errorMessage": error_message,
}
)
if req_payload is not UNSET:
field_dict["reqPayload"] = req_payload
if resp_payload is not UNSET:
field_dict["respPayload"] = resp_payload
return field_dict
@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
id = d.pop("id")
status_code = d.pop("statusCode")
error_message = d.pop("errorMessage")
req_payload = d.pop("reqPayload", UNSET)
resp_payload = d.pop("respPayload", UNSET)
local_testing_only_get_latest_logged_call_response_200_model_response = cls(
id=id,
status_code=status_code,
error_message=error_message,
req_payload=req_payload,
resp_payload=resp_payload,
)
return local_testing_only_get_latest_logged_call_response_200_model_response
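reqPayload and respPayload use the generator's UNSET sentinel, which lets to_dict distinguish "key absent" from "key present but null". For illustration (the ModelResponse alias is just for brevity):

from openpipe.api_client.models import (
    LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse as ModelResponse,
)

# No payload keys in the source dict: they stay UNSET and are omitted on output.
sparse = ModelResponse.from_dict({"id": "abc", "statusCode": 200, "errorMessage": None})
assert "reqPayload" not in sparse.to_dict()

# Present keys round-trip unchanged.
full = ModelResponse.from_dict(
    {"id": "abc", "statusCode": 200, "errorMessage": None, "reqPayload": {"n": 1}}
)
assert full.to_dict()["reqPayload"] == {"n": 1}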

View File

@@ -0,0 +1,43 @@
from typing import Any, Dict, List, Optional, Type, TypeVar
from attrs import define, field
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200Tags")
@define
class LocalTestingOnlyGetLatestLoggedCallResponse200Tags:
""" """
additional_properties: Dict[str, Optional[str]] = field(init=False, factory=dict)
def to_dict(self) -> Dict[str, Any]:
field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
return field_dict
@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
local_testing_only_get_latest_logged_call_response_200_tags = cls()
local_testing_only_get_latest_logged_call_response_200_tags.additional_properties = d
return local_testing_only_get_latest_logged_call_response_200_tags
@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())
def __getitem__(self, key: str) -> Optional[str]:
return self.additional_properties[key]
def __setitem__(self, key: str, value: Optional[str]) -> None:
self.additional_properties[key] = value
def __delitem__(self, key: str) -> None:
del self.additional_properties[key]
def __contains__(self, key: str) -> bool:
return key in self.additional_properties
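Since the tags object is open-ended in the schema, the model is a thin mapping wrapper around additional_properties, so logged tags read like a dict:

from openpipe.api_client.models import LocalTestingOnlyGetLatestLoggedCallResponse200Tags

tags = LocalTestingOnlyGetLatestLoggedCallResponse200Tags.from_dict(
    {"promptId": "testprompt", "$sdk": "python"}
)
assert tags["promptId"] == "testprompt"
assert "$sdk" in tags
assert sorted(tags.additional_keys) == ["$sdk", "promptId"]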

View File

@@ -1,9 +1,9 @@
 from typing import Any, Optional
-def merge_streamed_chunks(base: Optional[Any], chunk: Any) -> Any:
+def merge_openai_chunks(base: Optional[Any], chunk: Any) -> Any:
     if base is None:
-        return merge_streamed_chunks({**chunk, "choices": []}, chunk)
+        return merge_openai_chunks({**chunk, "choices": []}, chunk)
     choices = base["choices"].copy()
     for choice in chunk["choices"]:
@@ -34,9 +34,7 @@ def merge_streamed_chunks(base: Optional[Any], chunk: Any) -> Any:
             {**new_choice, "message": {"role": "assistant", **choice["delta"]}}
         )
-    merged = {
+    return {
         **base,
         "choices": choices,
     }
-    return merged
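The rename makes the function's job explicit: folding OpenAI streaming deltas into one completion-shaped dict, typically via reduce as the tests below do. A rough sketch with simplified stand-in chunks (real stream chunks carry more fields, and the exact merge semantics of the elided middle of the function are assumed here):

from functools import reduce

from openpipe.merge_openai_chunks import merge_openai_chunks

chunks = [
    {"id": "chatcmpl-1", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]},
    {"id": "chatcmpl-1", "choices": [{"index": 0, "delta": {"content": ", world"}}]},
]

merged = reduce(merge_openai_chunks, chunks, None)
# Expect roughly: merged["choices"][0]["message"] == {"role": "assistant", "content": "Hello, world"}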

View File

@@ -3,9 +3,16 @@ from openai.openai_object import OpenAIObject
 import time
 import inspect
-from openpipe.merge_openai_chunks import merge_streamed_chunks
+from openpipe.merge_openai_chunks import merge_openai_chunks
 from openpipe.openpipe_meta import OpenPipeMeta
-from .shared import maybe_check_cache, maybe_check_cache_async, report_async, report
+from .shared import (
+    _should_check_cache,
+    maybe_check_cache,
+    maybe_check_cache_async,
+    report_async,
+    report,
+)
class WrappedChatCompletion(original_openai.ChatCompletion):
@@ -29,9 +36,15 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
             def _gen():
                 assembled_completion = None
                 for chunk in chat_completion:
-                    assembled_completion = merge_streamed_chunks(
+                    assembled_completion = merge_openai_chunks(
                         assembled_completion, chunk
                     )
+                    cache_status = (
+                        "MISS" if _should_check_cache(openpipe_options) else "SKIP"
+                    )
+                    chunk.openpipe = OpenPipeMeta(cache_status=cache_status)
                     yield chunk
                 received_at = int(time.time() * 1000)
@@ -58,6 +71,10 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
status_code=200,
)
+            cache_status = (
+                "MISS" if _should_check_cache(openpipe_options) else "SKIP"
+            )
+            chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status)
return chat_completion
except Exception as e:
received_at = int(time.time() * 1000)
@@ -96,21 +113,28 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
         requested_at = int(time.time() * 1000)
         try:
-            chat_completion = original_openai.ChatCompletion.acreate(*args, **kwargs)
+            chat_completion = await original_openai.ChatCompletion.acreate(
+                *args, **kwargs
+            )
-            if inspect.isgenerator(chat_completion):
+            if inspect.isasyncgen(chat_completion):
-                def _gen():
+                async def _gen():
                     assembled_completion = None
-                    for chunk in chat_completion:
-                        assembled_completion = merge_streamed_chunks(
+                    async for chunk in chat_completion:
+                        assembled_completion = merge_openai_chunks(
                             assembled_completion, chunk
                         )
+                        cache_status = (
+                            "MISS" if _should_check_cache(openpipe_options) else "SKIP"
+                        )
+                        chunk.openpipe = OpenPipeMeta(cache_status=cache_status)
                         yield chunk
                     received_at = int(time.time() * 1000)
-                    report_async(
+                    await report_async(
@@ -123,7 +147,7 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
else:
received_at = int(time.time() * 1000)
-                report_async(
+                await report_async(
openpipe_options=openpipe_options,
requested_at=requested_at,
received_at=received_at,
@@ -132,12 +156,17 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
status_code=200,
)
+            cache_status = (
+                "MISS" if _should_check_cache(openpipe_options) else "SKIP"
+            )
+            chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status)
return chat_completion
except Exception as e:
received_at = int(time.time() * 1000)
if isinstance(e, original_openai.OpenAIError):
-                report_async(
+                await report_async(
openpipe_options=openpipe_options,
requested_at=requested_at,
received_at=received_at,
@@ -147,7 +176,7 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
status_code=e.http_status,
)
else:
-                report_async(
+                await report_async(
openpipe_options=openpipe_options,
requested_at=requested_at,
received_at=received_at,

View File

@@ -0,0 +1,7 @@
from attr import dataclass
@dataclass
class OpenPipeMeta:
# Cache status. One of 'HIT', 'MISS', 'SKIP'
cache_status: str
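OpenPipeMeta is the handle the wrapper attaches to completions and streamed chunks. Usage mirrors the tests in this commit; a sketch, assuming the wrapped openai module is imported from the openpipe package as the tests do:

from openpipe import openai

completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    openpipe={"cache": True},
)

# "HIT" when served from cache, "MISS" when caching is on but nothing was
# cached yet, "SKIP" when caching was not requested.
print(completion.openpipe.cache_status)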

View File

@@ -1,5 +1,5 @@
from openpipe.api_client.api.default import (
-    api_report,
+    report as api_report,
check_cache,
)
from openpipe.api_client.client import AuthenticatedClient

View File

@@ -1,55 +1,106 @@
+from functools import reduce
 from dotenv import load_dotenv
-from . import openai, configure_openpipe
 import os
+import pytest
+from . import openai, configure_openpipe, configured_client
+from .api_client.api.default import local_testing_only_get_latest_logged_call
+from .merge_openai_chunks import merge_openai_chunks
+import random
+import string
+def random_string(length):
+    letters = string.ascii_lowercase
+    return "".join(random.choice(letters) for i in range(length))
 load_dotenv()
 openai.api_key = os.getenv("OPENAI_API_KEY")
 configure_openpipe(
-    base_url="http://localhost:3000/api", api_key=os.getenv("OPENPIPE_API_KEY")
+    base_url="http://localhost:3000/api/v1", api_key=os.getenv("OPENPIPE_API_KEY")
 )
+def last_logged_call():
+    return local_testing_only_get_latest_logged_call.sync(client=configured_client)
 def test_sync():
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "count to 10"}],
+        messages=[{"role": "system", "content": "count to 3"}],
     )
     print(completion.choices[0].message.content)
+    last_logged = last_logged_call()
+    assert (
+        last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+        == completion.choices[0].message.content
+    )
+    assert (
+        last_logged.model_response.req_payload["messages"][0]["content"] == "count to 3"
+    )
+    assert completion.openpipe.cache_status == "SKIP"
 def test_streaming():
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "count to 10"}],
+        messages=[{"role": "system", "content": "count to 4"}],
         stream=True,
     )
-    for chunk in completion:
-        print(chunk)
+    merged = reduce(merge_openai_chunks, completion, None)
+    last_logged = last_logged_call()
+    assert (
+        last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+        == merged["choices"][0]["message"]["content"]
+    )
 async def test_async():
-    acompletion = await openai.ChatCompletion.acreate(
+    completion = await openai.ChatCompletion.acreate(
         model="gpt-3.5-turbo",
         messages=[{"role": "user", "content": "count down from 5"}],
     )
+    last_logged = last_logged_call()
+    assert (
+        last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+        == completion.choices[0].message.content
+    )
+    assert (
+        last_logged.model_response.req_payload["messages"][0]["content"]
+        == "count down from 5"
+    )
-    print(acompletion.choices[0].message.content)
+    assert completion.openpipe.cache_status == "SKIP"
 async def test_async_streaming():
-    acompletion = await openai.ChatCompletion.acreate(
+    completion = await openai.ChatCompletion.acreate(
         model="gpt-3.5-turbo",
         messages=[{"role": "user", "content": "count down from 5"}],
         stream=True,
     )
-    async for chunk in acompletion:
-        print(chunk)
+    merged = None
+    async for chunk in completion:
+        assert chunk.openpipe.cache_status == "SKIP"
+        merged = merge_openai_chunks(merged, chunk)
+    last_logged = last_logged_call()
+    assert (
+        last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+        == merged["choices"][0]["message"]["content"]
+    )
+    assert (
+        last_logged.model_response.req_payload["messages"][0]["content"]
+        == "count down from 5"
+    )
+    assert merged["openpipe"].cache_status == "SKIP"
def test_sync_with_tags():
@@ -58,31 +109,54 @@ def test_sync_with_tags():
         messages=[{"role": "system", "content": "count to 10"}],
         openpipe={"tags": {"promptId": "testprompt"}},
     )
-    print("finished")
     print(completion.choices[0].message.content)
+    last_logged = last_logged_call()
+    assert (
+        last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+        == completion.choices[0].message.content
+    )
+    print(last_logged.tags)
+    assert last_logged.tags["promptId"] == "testprompt"
+    assert last_logged.tags["$sdk"] == "python"
 def test_bad_call():
-    completion = openai.ChatCompletion.create(
-        model="gpt-3.5-turbo-blaster",
-        messages=[{"role": "system", "content": "count to 10"}],
-        stream=True,
-    )
+    try:
+        completion = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-blaster",
+            messages=[{"role": "system", "content": "count to 10"}],
+            stream=True,
+        )
+        assert False
+    except Exception as e:
+        pass
+    last_logged = last_logged_call()
+    print(last_logged)
+    assert (
+        last_logged.model_response.error_message
+        == "The model `gpt-3.5-turbo-blaster` does not exist"
+    )
+    assert last_logged.model_response.status_code == 404
+@pytest.mark.focus
 async def test_caching():
+    messages = [{"role": "system", "content": f"repeat '{random_string(10)}'"}]
     completion = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "count to 10"}],
+        messages=messages,
         openpipe={"cache": True},
     )
+    assert completion.openpipe.cache_status == "MISS"
+    first_logged = last_logged_call()
+    assert (
+        completion.choices[0].message.content
+        == first_logged.model_response.resp_payload["choices"][0]["message"]["content"]
+    )
     completion2 = await openai.ChatCompletion.acreate(
         model="gpt-3.5-turbo",
-        messages=[{"role": "system", "content": "count to 10"}],
+        messages=messages,
         openpipe={"cache": True},
     )
     print(completion2)
+    assert completion2.openpipe.cache_status == "HIT"