Python package improvements
Added an endpoint for getting the actual stored responses, and used it to test and improve the python package.
This commit is contained in:
@@ -32,5 +32,5 @@ NEXT_PUBLIC_HOST="http://localhost:3000"
|
||||
GITHUB_CLIENT_ID="your_client_id"
|
||||
GITHUB_CLIENT_SECRET="your_secret"
|
||||
|
||||
OPENPIPE_BASE_URL="http://localhost:3000/api"
|
||||
OPENPIPE_BASE_URL="http://localhost:3000/api/v1"
|
||||
OPENPIPE_API_KEY="your_key"
|
||||
|
||||
@@ -2,6 +2,7 @@ import { prisma } from "~/server/db";
|
||||
import dedent from "dedent";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
import { promptConstructorVersion } from "~/promptConstructor/version";
|
||||
import { env } from "~/env.mjs";
|
||||
|
||||
const defaultId = "11111111-1111-1111-1111-111111111111";
|
||||
|
||||
@@ -16,6 +17,16 @@ const project =
|
||||
data: { id: defaultId },
|
||||
}));
|
||||
|
||||
if (env.OPENPIPE_API_KEY) {
|
||||
await prisma.apiKey.create({
|
||||
data: {
|
||||
projectId: project.id,
|
||||
name: "Default API Key",
|
||||
apiKey: env.OPENPIPE_API_KEY,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
await prisma.experiment.deleteMany({
|
||||
where: {
|
||||
id: defaultId,
|
||||
|
||||
86
app/src/server/api/external/v1Api.router.ts
vendored
86
app/src/server/api/external/v1Api.router.ts
vendored
@@ -45,7 +45,8 @@ export const v1ApiRouter = createOpenApiRouter({
|
||||
.optional()
|
||||
.describe(
|
||||
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
|
||||
),
|
||||
)
|
||||
.default({}),
|
||||
}),
|
||||
)
|
||||
.output(
|
||||
@@ -74,6 +75,7 @@ export const v1ApiRouter = createOpenApiRouter({
|
||||
},
|
||||
});
|
||||
|
||||
await createTags(existingResponse.originalLoggedCallId, input.tags);
|
||||
return {
|
||||
respPayload: existingResponse.respPayload,
|
||||
};
|
||||
@@ -101,16 +103,16 @@ export const v1ApiRouter = createOpenApiRouter({
|
||||
.optional()
|
||||
.describe(
|
||||
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
|
||||
),
|
||||
)
|
||||
.default({}),
|
||||
}),
|
||||
)
|
||||
.output(z.void())
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
console.log("GOT TAGS", input.tags);
|
||||
const reqPayload = await reqValidator.spa(input.reqPayload);
|
||||
const respPayload = await respValidator.spa(input.respPayload);
|
||||
|
||||
const requestHash = hashRequest(ctx.key.project.id, reqPayload as JsonValue);
|
||||
const requestHash = hashRequest(ctx.key.projectId, reqPayload as JsonValue);
|
||||
|
||||
const newLoggedCallId = uuidv4();
|
||||
const newModelResponseId = uuidv4();
|
||||
@@ -129,7 +131,7 @@ export const v1ApiRouter = createOpenApiRouter({
|
||||
prisma.loggedCall.create({
|
||||
data: {
|
||||
id: newLoggedCallId,
|
||||
projectId: ctx.key.project.id,
|
||||
projectId: ctx.key.projectId,
|
||||
requestedAt: new Date(input.requestedAt),
|
||||
cacheHit: false,
|
||||
model,
|
||||
@@ -163,14 +165,76 @@ export const v1ApiRouter = createOpenApiRouter({
|
||||
}),
|
||||
]);
|
||||
|
||||
const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
|
||||
loggedCallId: newLoggedCallId,
|
||||
// sanitize tags
|
||||
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
|
||||
await createTags(newLoggedCallId, input.tags);
|
||||
}),
|
||||
localTestingOnlyGetLatestLoggedCall: openApiProtectedProc
|
||||
.meta({
|
||||
openapi: {
|
||||
method: "GET",
|
||||
path: "/local-testing-only-get-latest-logged-call",
|
||||
description: "Get the latest logged call (only for local testing)",
|
||||
protect: true, // Make sure to protect this endpoint
|
||||
},
|
||||
})
|
||||
.input(z.void())
|
||||
.output(
|
||||
z
|
||||
.object({
|
||||
createdAt: z.date(),
|
||||
cacheHit: z.boolean(),
|
||||
tags: z.record(z.string().nullable()),
|
||||
modelResponse: z
|
||||
.object({
|
||||
id: z.string(),
|
||||
statusCode: z.number().nullable(),
|
||||
errorMessage: z.string().nullable(),
|
||||
reqPayload: z.unknown(),
|
||||
respPayload: z.unknown(),
|
||||
})
|
||||
.nullable(),
|
||||
})
|
||||
.nullable(),
|
||||
)
|
||||
.mutation(async ({ ctx }) => {
|
||||
if (process.env.NODE_ENV === "production") {
|
||||
throw new Error("This operation is not allowed in production environment");
|
||||
}
|
||||
|
||||
const latestLoggedCall = await prisma.loggedCall.findFirst({
|
||||
where: { projectId: ctx.key.projectId },
|
||||
orderBy: { requestedAt: "desc" },
|
||||
select: {
|
||||
createdAt: true,
|
||||
cacheHit: true,
|
||||
tags: true,
|
||||
modelResponse: {
|
||||
select: {
|
||||
id: true,
|
||||
statusCode: true,
|
||||
errorMessage: true,
|
||||
reqPayload: true,
|
||||
respPayload: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
latestLoggedCall && {
|
||||
...latestLoggedCall,
|
||||
tags: Object.fromEntries(latestLoggedCall.tags.map((tag) => [tag.name, tag.value])),
|
||||
}
|
||||
);
|
||||
}),
|
||||
});
|
||||
|
||||
async function createTags(loggedCallId: string, tags: Record<string, string>) {
|
||||
const tagsToCreate = Object.entries(tags).map(([name, value]) => ({
|
||||
loggedCallId,
|
||||
name: name.replaceAll(/[^a-zA-Z0-9_$]/g, "_"),
|
||||
value,
|
||||
}));
|
||||
await prisma.loggedCallTag.createMany({
|
||||
data: tagsToCreate,
|
||||
});
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,4 +2,4 @@ import cryptoRandomString from "crypto-random-string";
|
||||
|
||||
const KEY_LENGTH = 42;
|
||||
|
||||
export const generateApiKey = () => `opc_${cryptoRandomString({ length: KEY_LENGTH })}`;
|
||||
export const generateApiKey = () => `opk_${cryptoRandomString({ length: KEY_LENGTH })}`;
|
||||
|
||||
@@ -39,7 +39,8 @@
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Extra tags to attach to the call for filtering. Eg { \"userId\": \"123\", \"promptId\": \"populate-title\" }"
|
||||
"description": "Extra tags to attach to the call for filtering. Eg { \"userId\": \"123\", \"promptId\": \"populate-title\" }",
|
||||
"default": {}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
@@ -117,7 +118,8 @@
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Extra tags to attach to the call for filtering. Eg { \"userId\": \"123\", \"promptId\": \"populate-title\" }"
|
||||
"description": "Extra tags to attach to the call for filtering. Eg { \"userId\": \"123\", \"promptId\": \"populate-title\" }",
|
||||
"default": {}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
@@ -144,6 +146,82 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/local-testing-only-get-latest-logged-call": {
|
||||
"get": {
|
||||
"operationId": "localTestingOnlyGetLatestLoggedCall",
|
||||
"description": "Get the latest logged call (only for local testing)",
|
||||
"security": [
|
||||
{
|
||||
"Authorization": []
|
||||
}
|
||||
],
|
||||
"parameters": [],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"createdAt": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"cacheHit": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"tags": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
"modelResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"statusCode": {
|
||||
"type": "number",
|
||||
"nullable": true
|
||||
},
|
||||
"errorMessage": {
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
},
|
||||
"reqPayload": {},
|
||||
"respPayload": {}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"statusCode",
|
||||
"errorMessage"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"createdAt",
|
||||
"cacheHit",
|
||||
"tags",
|
||||
"modelResponse"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/error"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ... import errors
|
||||
from ...client import AuthenticatedClient, Client
|
||||
from ...models.local_testing_only_get_latest_logged_call_response_200 import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200,
|
||||
)
|
||||
from ...types import Response
|
||||
|
||||
|
||||
def _get_kwargs() -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
return {
|
||||
"method": "get",
|
||||
"url": "/local-testing-only-get-latest-logged-call",
|
||||
}
|
||||
|
||||
|
||||
def _parse_response(
|
||||
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
|
||||
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
_response_200 = response.json()
|
||||
response_200: Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
|
||||
if _response_200 is None:
|
||||
response_200 = None
|
||||
else:
|
||||
response_200 = LocalTestingOnlyGetLatestLoggedCallResponse200.from_dict(_response_200)
|
||||
|
||||
return response_200
|
||||
if client.raise_on_unexpected_status:
|
||||
raise errors.UnexpectedStatus(response.status_code, response.content)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _build_response(
|
||||
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
|
||||
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
return Response(
|
||||
status_code=HTTPStatus(response.status_code),
|
||||
content=response.content,
|
||||
headers=response.headers,
|
||||
parsed=_parse_response(client=client, response=response),
|
||||
)
|
||||
|
||||
|
||||
def sync_detailed(
|
||||
*,
|
||||
client: AuthenticatedClient,
|
||||
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
"""Get the latest logged call (only for local testing)
|
||||
|
||||
Raises:
|
||||
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
|
||||
httpx.TimeoutException: If the request takes longer than Client.timeout.
|
||||
|
||||
Returns:
|
||||
Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]
|
||||
"""
|
||||
|
||||
kwargs = _get_kwargs()
|
||||
|
||||
response = client.get_httpx_client().request(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _build_response(client=client, response=response)
|
||||
|
||||
|
||||
def sync(
|
||||
*,
|
||||
client: AuthenticatedClient,
|
||||
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
"""Get the latest logged call (only for local testing)
|
||||
|
||||
Raises:
|
||||
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
|
||||
httpx.TimeoutException: If the request takes longer than Client.timeout.
|
||||
|
||||
Returns:
|
||||
Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
|
||||
"""
|
||||
|
||||
return sync_detailed(
|
||||
client=client,
|
||||
).parsed
|
||||
|
||||
|
||||
async def asyncio_detailed(
|
||||
*,
|
||||
client: AuthenticatedClient,
|
||||
) -> Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
"""Get the latest logged call (only for local testing)
|
||||
|
||||
Raises:
|
||||
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
|
||||
httpx.TimeoutException: If the request takes longer than Client.timeout.
|
||||
|
||||
Returns:
|
||||
Response[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]
|
||||
"""
|
||||
|
||||
kwargs = _get_kwargs()
|
||||
|
||||
response = await client.get_async_httpx_client().request(**kwargs)
|
||||
|
||||
return _build_response(client=client, response=response)
|
||||
|
||||
|
||||
async def asyncio(
|
||||
*,
|
||||
client: AuthenticatedClient,
|
||||
) -> Optional[Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]]:
|
||||
"""Get the latest logged call (only for local testing)
|
||||
|
||||
Raises:
|
||||
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
|
||||
httpx.TimeoutException: If the request takes longer than Client.timeout.
|
||||
|
||||
Returns:
|
||||
Optional[LocalTestingOnlyGetLatestLoggedCallResponse200]
|
||||
"""
|
||||
|
||||
return (
|
||||
await asyncio_detailed(
|
||||
client=client,
|
||||
)
|
||||
).parsed
|
||||
@@ -3,6 +3,13 @@
|
||||
from .check_cache_json_body import CheckCacheJsonBody
|
||||
from .check_cache_json_body_tags import CheckCacheJsonBodyTags
|
||||
from .check_cache_response_200 import CheckCacheResponse200
|
||||
from .local_testing_only_get_latest_logged_call_response_200 import LocalTestingOnlyGetLatestLoggedCallResponse200
|
||||
from .local_testing_only_get_latest_logged_call_response_200_model_response import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
|
||||
)
|
||||
from .local_testing_only_get_latest_logged_call_response_200_tags import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
|
||||
)
|
||||
from .report_json_body import ReportJsonBody
|
||||
from .report_json_body_tags import ReportJsonBodyTags
|
||||
|
||||
@@ -10,6 +17,9 @@ __all__ = (
|
||||
"CheckCacheJsonBody",
|
||||
"CheckCacheJsonBodyTags",
|
||||
"CheckCacheResponse200",
|
||||
"LocalTestingOnlyGetLatestLoggedCallResponse200",
|
||||
"LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse",
|
||||
"LocalTestingOnlyGetLatestLoggedCallResponse200Tags",
|
||||
"ReportJsonBody",
|
||||
"ReportJsonBodyTags",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar
|
||||
|
||||
from attrs import define
|
||||
from dateutil.parser import isoparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.local_testing_only_get_latest_logged_call_response_200_model_response import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
|
||||
)
|
||||
from ..models.local_testing_only_get_latest_logged_call_response_200_tags import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200")
|
||||
|
||||
|
||||
@define
|
||||
class LocalTestingOnlyGetLatestLoggedCallResponse200:
|
||||
"""
|
||||
Attributes:
|
||||
created_at (datetime.datetime):
|
||||
cache_hit (bool):
|
||||
tags (LocalTestingOnlyGetLatestLoggedCallResponse200Tags):
|
||||
model_response (Optional[LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse]):
|
||||
"""
|
||||
|
||||
created_at: datetime.datetime
|
||||
cache_hit: bool
|
||||
tags: "LocalTestingOnlyGetLatestLoggedCallResponse200Tags"
|
||||
model_response: Optional["LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse"]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
created_at = self.created_at.isoformat()
|
||||
|
||||
cache_hit = self.cache_hit
|
||||
tags = self.tags.to_dict()
|
||||
|
||||
model_response = self.model_response.to_dict() if self.model_response else None
|
||||
|
||||
field_dict: Dict[str, Any] = {}
|
||||
field_dict.update(
|
||||
{
|
||||
"createdAt": created_at,
|
||||
"cacheHit": cache_hit,
|
||||
"tags": tags,
|
||||
"modelResponse": model_response,
|
||||
}
|
||||
)
|
||||
|
||||
return field_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
|
||||
from ..models.local_testing_only_get_latest_logged_call_response_200_model_response import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse,
|
||||
)
|
||||
from ..models.local_testing_only_get_latest_logged_call_response_200_tags import (
|
||||
LocalTestingOnlyGetLatestLoggedCallResponse200Tags,
|
||||
)
|
||||
|
||||
d = src_dict.copy()
|
||||
created_at = isoparse(d.pop("createdAt"))
|
||||
|
||||
cache_hit = d.pop("cacheHit")
|
||||
|
||||
tags = LocalTestingOnlyGetLatestLoggedCallResponse200Tags.from_dict(d.pop("tags"))
|
||||
|
||||
_model_response = d.pop("modelResponse")
|
||||
model_response: Optional[LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse]
|
||||
if _model_response is None:
|
||||
model_response = None
|
||||
else:
|
||||
model_response = LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse.from_dict(_model_response)
|
||||
|
||||
local_testing_only_get_latest_logged_call_response_200 = cls(
|
||||
created_at=created_at,
|
||||
cache_hit=cache_hit,
|
||||
tags=tags,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
return local_testing_only_get_latest_logged_call_response_200
|
||||
@@ -0,0 +1,70 @@
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
||||
|
||||
from attrs import define
|
||||
|
||||
from ..types import UNSET, Unset
|
||||
|
||||
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse")
|
||||
|
||||
|
||||
@define
|
||||
class LocalTestingOnlyGetLatestLoggedCallResponse200ModelResponse:
|
||||
"""
|
||||
Attributes:
|
||||
id (str):
|
||||
status_code (Optional[float]):
|
||||
error_message (Optional[str]):
|
||||
req_payload (Union[Unset, Any]):
|
||||
resp_payload (Union[Unset, Any]):
|
||||
"""
|
||||
|
||||
id: str
|
||||
status_code: Optional[float]
|
||||
error_message: Optional[str]
|
||||
req_payload: Union[Unset, Any] = UNSET
|
||||
resp_payload: Union[Unset, Any] = UNSET
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
id = self.id
|
||||
status_code = self.status_code
|
||||
error_message = self.error_message
|
||||
req_payload = self.req_payload
|
||||
resp_payload = self.resp_payload
|
||||
|
||||
field_dict: Dict[str, Any] = {}
|
||||
field_dict.update(
|
||||
{
|
||||
"id": id,
|
||||
"statusCode": status_code,
|
||||
"errorMessage": error_message,
|
||||
}
|
||||
)
|
||||
if req_payload is not UNSET:
|
||||
field_dict["reqPayload"] = req_payload
|
||||
if resp_payload is not UNSET:
|
||||
field_dict["respPayload"] = resp_payload
|
||||
|
||||
return field_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
|
||||
d = src_dict.copy()
|
||||
id = d.pop("id")
|
||||
|
||||
status_code = d.pop("statusCode")
|
||||
|
||||
error_message = d.pop("errorMessage")
|
||||
|
||||
req_payload = d.pop("reqPayload", UNSET)
|
||||
|
||||
resp_payload = d.pop("respPayload", UNSET)
|
||||
|
||||
local_testing_only_get_latest_logged_call_response_200_model_response = cls(
|
||||
id=id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
req_payload=req_payload,
|
||||
resp_payload=resp_payload,
|
||||
)
|
||||
|
||||
return local_testing_only_get_latest_logged_call_response_200_model_response
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar
|
||||
|
||||
from attrs import define, field
|
||||
|
||||
T = TypeVar("T", bound="LocalTestingOnlyGetLatestLoggedCallResponse200Tags")
|
||||
|
||||
|
||||
@define
|
||||
class LocalTestingOnlyGetLatestLoggedCallResponse200Tags:
|
||||
""" """
|
||||
|
||||
additional_properties: Dict[str, Optional[str]] = field(init=False, factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
field_dict: Dict[str, Any] = {}
|
||||
field_dict.update(self.additional_properties)
|
||||
field_dict.update({})
|
||||
|
||||
return field_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
|
||||
d = src_dict.copy()
|
||||
local_testing_only_get_latest_logged_call_response_200_tags = cls()
|
||||
|
||||
local_testing_only_get_latest_logged_call_response_200_tags.additional_properties = d
|
||||
return local_testing_only_get_latest_logged_call_response_200_tags
|
||||
|
||||
@property
|
||||
def additional_keys(self) -> List[str]:
|
||||
return list(self.additional_properties.keys())
|
||||
|
||||
def __getitem__(self, key: str) -> Optional[str]:
|
||||
return self.additional_properties[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Optional[str]) -> None:
|
||||
self.additional_properties[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self.additional_properties[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.additional_properties
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def merge_streamed_chunks(base: Optional[Any], chunk: Any) -> Any:
|
||||
def merge_openai_chunks(base: Optional[Any], chunk: Any) -> Any:
|
||||
if base is None:
|
||||
return merge_streamed_chunks({**chunk, "choices": []}, chunk)
|
||||
return merge_openai_chunks({**chunk, "choices": []}, chunk)
|
||||
|
||||
choices = base["choices"].copy()
|
||||
for choice in chunk["choices"]:
|
||||
@@ -34,9 +34,7 @@ def merge_streamed_chunks(base: Optional[Any], chunk: Any) -> Any:
|
||||
{**new_choice, "message": {"role": "assistant", **choice["delta"]}}
|
||||
)
|
||||
|
||||
merged = {
|
||||
return {
|
||||
**base,
|
||||
"choices": choices,
|
||||
}
|
||||
|
||||
return merged
|
||||
|
||||
@@ -3,9 +3,16 @@ from openai.openai_object import OpenAIObject
|
||||
import time
|
||||
import inspect
|
||||
|
||||
from openpipe.merge_openai_chunks import merge_streamed_chunks
|
||||
from openpipe.merge_openai_chunks import merge_openai_chunks
|
||||
from openpipe.openpipe_meta import OpenPipeMeta
|
||||
|
||||
from .shared import maybe_check_cache, maybe_check_cache_async, report_async, report
|
||||
from .shared import (
|
||||
_should_check_cache,
|
||||
maybe_check_cache,
|
||||
maybe_check_cache_async,
|
||||
report_async,
|
||||
report,
|
||||
)
|
||||
|
||||
|
||||
class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
@@ -29,9 +36,15 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
def _gen():
|
||||
assembled_completion = None
|
||||
for chunk in chat_completion:
|
||||
assembled_completion = merge_streamed_chunks(
|
||||
assembled_completion = merge_openai_chunks(
|
||||
assembled_completion, chunk
|
||||
)
|
||||
|
||||
cache_status = (
|
||||
"MISS" if _should_check_cache(openpipe_options) else "SKIP"
|
||||
)
|
||||
chunk.openpipe = OpenPipeMeta(cache_status=cache_status)
|
||||
|
||||
yield chunk
|
||||
|
||||
received_at = int(time.time() * 1000)
|
||||
@@ -58,6 +71,10 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
cache_status = (
|
||||
"MISS" if _should_check_cache(openpipe_options) else "SKIP"
|
||||
)
|
||||
chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status)
|
||||
return chat_completion
|
||||
except Exception as e:
|
||||
received_at = int(time.time() * 1000)
|
||||
@@ -96,21 +113,28 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
requested_at = int(time.time() * 1000)
|
||||
|
||||
try:
|
||||
chat_completion = original_openai.ChatCompletion.acreate(*args, **kwargs)
|
||||
chat_completion = await original_openai.ChatCompletion.acreate(
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
if inspect.isgenerator(chat_completion):
|
||||
if inspect.isasyncgen(chat_completion):
|
||||
|
||||
def _gen():
|
||||
async def _gen():
|
||||
assembled_completion = None
|
||||
for chunk in chat_completion:
|
||||
assembled_completion = merge_streamed_chunks(
|
||||
async for chunk in chat_completion:
|
||||
assembled_completion = merge_openai_chunks(
|
||||
assembled_completion, chunk
|
||||
)
|
||||
cache_status = (
|
||||
"MISS" if _should_check_cache(openpipe_options) else "SKIP"
|
||||
)
|
||||
chunk.openpipe = OpenPipeMeta(cache_status=cache_status)
|
||||
|
||||
yield chunk
|
||||
|
||||
received_at = int(time.time() * 1000)
|
||||
|
||||
report_async(
|
||||
await report_async(
|
||||
openpipe_options=openpipe_options,
|
||||
requested_at=requested_at,
|
||||
received_at=received_at,
|
||||
@@ -123,7 +147,7 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
else:
|
||||
received_at = int(time.time() * 1000)
|
||||
|
||||
report_async(
|
||||
await report_async(
|
||||
openpipe_options=openpipe_options,
|
||||
requested_at=requested_at,
|
||||
received_at=received_at,
|
||||
@@ -132,12 +156,17 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
cache_status = (
|
||||
"MISS" if _should_check_cache(openpipe_options) else "SKIP"
|
||||
)
|
||||
chat_completion["openpipe"] = OpenPipeMeta(cache_status=cache_status)
|
||||
|
||||
return chat_completion
|
||||
except Exception as e:
|
||||
received_at = int(time.time() * 1000)
|
||||
|
||||
if isinstance(e, original_openai.OpenAIError):
|
||||
report_async(
|
||||
await report_async(
|
||||
openpipe_options=openpipe_options,
|
||||
requested_at=requested_at,
|
||||
received_at=received_at,
|
||||
@@ -147,7 +176,7 @@ class WrappedChatCompletion(original_openai.ChatCompletion):
|
||||
status_code=e.http_status,
|
||||
)
|
||||
else:
|
||||
report_async(
|
||||
await report_async(
|
||||
openpipe_options=openpipe_options,
|
||||
requested_at=requested_at,
|
||||
received_at=received_at,
|
||||
|
||||
7
client-libs/python/openpipe/openpipe_meta.py
Normal file
7
client-libs/python/openpipe/openpipe_meta.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from attr import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenPipeMeta:
|
||||
# Cache status. One of 'HIT', 'MISS', 'SKIP'
|
||||
cache_status: str
|
||||
@@ -1,5 +1,5 @@
|
||||
from openpipe.api_client.api.default import (
|
||||
api_report,
|
||||
report as api_report,
|
||||
check_cache,
|
||||
)
|
||||
from openpipe.api_client.client import AuthenticatedClient
|
||||
|
||||
@@ -1,55 +1,106 @@
|
||||
from functools import reduce
|
||||
from dotenv import load_dotenv
|
||||
from . import openai, configure_openpipe
|
||||
import os
|
||||
import pytest
|
||||
from . import openai, configure_openpipe, configured_client
|
||||
from .api_client.api.default import local_testing_only_get_latest_logged_call
|
||||
from .merge_openai_chunks import merge_openai_chunks
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
def random_string(length):
|
||||
letters = string.ascii_lowercase
|
||||
return "".join(random.choice(letters) for i in range(length))
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
configure_openpipe(
|
||||
base_url="http://localhost:3000/api", api_key=os.getenv("OPENPIPE_API_KEY")
|
||||
base_url="http://localhost:3000/api/v1", api_key=os.getenv("OPENPIPE_API_KEY")
|
||||
)
|
||||
|
||||
|
||||
def last_logged_call():
|
||||
return local_testing_only_get_latest_logged_call.sync(client=configured_client)
|
||||
|
||||
|
||||
def test_sync():
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
messages=[{"role": "system", "content": "count to 3"}],
|
||||
)
|
||||
|
||||
print(completion.choices[0].message.content)
|
||||
last_logged = last_logged_call()
|
||||
assert (
|
||||
last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
== completion.choices[0].message.content
|
||||
)
|
||||
assert (
|
||||
last_logged.model_response.req_payload["messages"][0]["content"] == "count to 3"
|
||||
)
|
||||
|
||||
assert completion.openpipe.cache_status == "SKIP"
|
||||
|
||||
|
||||
def test_streaming():
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
messages=[{"role": "system", "content": "count to 4"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for chunk in completion:
|
||||
print(chunk)
|
||||
merged = reduce(merge_openai_chunks, completion, None)
|
||||
last_logged = last_logged_call()
|
||||
assert (
|
||||
last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
== merged["choices"][0]["message"]["content"]
|
||||
)
|
||||
|
||||
|
||||
async def test_async():
|
||||
acompletion = await openai.ChatCompletion.acreate(
|
||||
completion = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "count down from 5"}],
|
||||
)
|
||||
last_logged = last_logged_call()
|
||||
assert (
|
||||
last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
== completion.choices[0].message.content
|
||||
)
|
||||
assert (
|
||||
last_logged.model_response.req_payload["messages"][0]["content"]
|
||||
== "count down from 5"
|
||||
)
|
||||
|
||||
print(acompletion.choices[0].message.content)
|
||||
assert completion.openpipe.cache_status == "SKIP"
|
||||
|
||||
|
||||
async def test_async_streaming():
|
||||
acompletion = await openai.ChatCompletion.acreate(
|
||||
completion = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "count down from 5"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in acompletion:
|
||||
print(chunk)
|
||||
merged = None
|
||||
async for chunk in completion:
|
||||
assert chunk.openpipe.cache_status == "SKIP"
|
||||
merged = merge_openai_chunks(merged, chunk)
|
||||
|
||||
last_logged = last_logged_call()
|
||||
|
||||
assert (
|
||||
last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
== merged["choices"][0]["message"]["content"]
|
||||
)
|
||||
assert (
|
||||
last_logged.model_response.req_payload["messages"][0]["content"]
|
||||
== "count down from 5"
|
||||
)
|
||||
assert merged["openpipe"].cache_status == "SKIP"
|
||||
|
||||
|
||||
def test_sync_with_tags():
|
||||
@@ -58,31 +109,54 @@ def test_sync_with_tags():
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
openpipe={"tags": {"promptId": "testprompt"}},
|
||||
)
|
||||
print("finished")
|
||||
|
||||
print(completion.choices[0].message.content)
|
||||
last_logged = last_logged_call()
|
||||
assert (
|
||||
last_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
== completion.choices[0].message.content
|
||||
)
|
||||
print(last_logged.tags)
|
||||
assert last_logged.tags["promptId"] == "testprompt"
|
||||
assert last_logged.tags["$sdk"] == "python"
|
||||
|
||||
|
||||
def test_bad_call():
|
||||
try:
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo-blaster",
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
stream=True,
|
||||
)
|
||||
assert False
|
||||
except Exception as e:
|
||||
pass
|
||||
last_logged = last_logged_call()
|
||||
print(last_logged)
|
||||
assert (
|
||||
last_logged.model_response.error_message
|
||||
== "The model `gpt-3.5-turbo-blaster` does not exist"
|
||||
)
|
||||
assert last_logged.model_response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.focus
|
||||
async def test_caching():
|
||||
messages = [{"role": "system", "content": f"repeat '{random_string(10)}'"}]
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
messages=messages,
|
||||
openpipe={"cache": True},
|
||||
)
|
||||
assert completion.openpipe.cache_status == "MISS"
|
||||
|
||||
first_logged = last_logged_call()
|
||||
assert (
|
||||
completion.choices[0].message.content
|
||||
== first_logged.model_response.resp_payload["choices"][0]["message"]["content"]
|
||||
)
|
||||
|
||||
completion2 = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "system", "content": "count to 10"}],
|
||||
openpipe={"cache": True},
|
||||
)
|
||||
|
||||
print(completion2)
|
||||
assert completion2.openpipe.cache_status == "HIT"
|
||||
|
||||
@@ -99,6 +99,74 @@ export interface CheckCacheRequest {
|
||||
*/
|
||||
'tags'?: { [key: string]: string; };
|
||||
}
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
* @interface LocalTestingOnlyGetLatestLoggedCall200Response
|
||||
*/
|
||||
export interface LocalTestingOnlyGetLatestLoggedCall200Response {
|
||||
/**
|
||||
*
|
||||
* @type {string}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200Response
|
||||
*/
|
||||
'createdAt': string;
|
||||
/**
|
||||
*
|
||||
* @type {boolean}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200Response
|
||||
*/
|
||||
'cacheHit': boolean;
|
||||
/**
|
||||
*
|
||||
* @type {{ [key: string]: string; }}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200Response
|
||||
*/
|
||||
'tags': { [key: string]: string; };
|
||||
/**
|
||||
*
|
||||
* @type {LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200Response
|
||||
*/
|
||||
'modelResponse': LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse | null;
|
||||
}
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
* @interface LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
export interface LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse {
|
||||
/**
|
||||
*
|
||||
* @type {string}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
'id': string;
|
||||
/**
|
||||
*
|
||||
* @type {number}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
'statusCode': number | null;
|
||||
/**
|
||||
*
|
||||
* @type {string}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
'errorMessage': string | null;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
'reqPayload'?: any;
|
||||
/**
|
||||
*
|
||||
* @type {any}
|
||||
* @memberof LocalTestingOnlyGetLatestLoggedCall200ResponseModelResponse
|
||||
*/
|
||||
'respPayload'?: any;
|
||||
}
|
||||
/**
|
||||
*
|
||||
* @export
|
||||
@@ -194,6 +262,39 @@ export const DefaultApiAxiosParamCreator = function (configuration?: Configurati
|
||||
options: localVarRequestOptions,
|
||||
};
|
||||
},
|
||||
/**
|
||||
* Get the latest logged call (only for local testing)
|
||||
* @param {*} [options] Override http request option.
|
||||
* @throws {RequiredError}
|
||||
*/
|
||||
localTestingOnlyGetLatestLoggedCall: async (options: AxiosRequestConfig = {}): Promise<RequestArgs> => {
|
||||
const localVarPath = `/local-testing-only-get-latest-logged-call`;
|
||||
// use dummy base URL string because the URL constructor only accepts absolute URLs.
|
||||
const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL);
|
||||
let baseOptions;
|
||||
if (configuration) {
|
||||
baseOptions = configuration.baseOptions;
|
||||
}
|
||||
|
||||
const localVarRequestOptions = { method: 'GET', ...baseOptions, ...options};
|
||||
const localVarHeaderParameter = {} as any;
|
||||
const localVarQueryParameter = {} as any;
|
||||
|
||||
// authentication Authorization required
|
||||
// http bearer authentication required
|
||||
await setBearerAuthToObject(localVarHeaderParameter, configuration)
|
||||
|
||||
|
||||
|
||||
setSearchParams(localVarUrlObj, localVarQueryParameter);
|
||||
let headersFromBaseOptions = baseOptions && baseOptions.headers ? baseOptions.headers : {};
|
||||
localVarRequestOptions.headers = {...localVarHeaderParameter, ...headersFromBaseOptions, ...options.headers};
|
||||
|
||||
return {
|
||||
url: toPathString(localVarUrlObj),
|
||||
options: localVarRequestOptions,
|
||||
};
|
||||
},
|
||||
/**
|
||||
* Report an API call
|
||||
* @param {ReportRequest} reportRequest
|
||||
@@ -253,6 +354,15 @@ export const DefaultApiFp = function(configuration?: Configuration) {
|
||||
const localVarAxiosArgs = await localVarAxiosParamCreator.checkCache(checkCacheRequest, options);
|
||||
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
|
||||
},
|
||||
/**
|
||||
* Get the latest logged call (only for local testing)
|
||||
* @param {*} [options] Override http request option.
|
||||
* @throws {RequiredError}
|
||||
*/
|
||||
async localTestingOnlyGetLatestLoggedCall(options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<LocalTestingOnlyGetLatestLoggedCall200Response>> {
|
||||
const localVarAxiosArgs = await localVarAxiosParamCreator.localTestingOnlyGetLatestLoggedCall(options);
|
||||
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
|
||||
},
|
||||
/**
|
||||
* Report an API call
|
||||
* @param {ReportRequest} reportRequest
|
||||
@@ -282,6 +392,14 @@ export const DefaultApiFactory = function (configuration?: Configuration, basePa
|
||||
checkCache(checkCacheRequest: CheckCacheRequest, options?: any): AxiosPromise<CheckCache200Response> {
|
||||
return localVarFp.checkCache(checkCacheRequest, options).then((request) => request(axios, basePath));
|
||||
},
|
||||
/**
|
||||
* Get the latest logged call (only for local testing)
|
||||
* @param {*} [options] Override http request option.
|
||||
* @throws {RequiredError}
|
||||
*/
|
||||
localTestingOnlyGetLatestLoggedCall(options?: any): AxiosPromise<LocalTestingOnlyGetLatestLoggedCall200Response> {
|
||||
return localVarFp.localTestingOnlyGetLatestLoggedCall(options).then((request) => request(axios, basePath));
|
||||
},
|
||||
/**
|
||||
* Report an API call
|
||||
* @param {ReportRequest} reportRequest
|
||||
@@ -312,6 +430,16 @@ export class DefaultApi extends BaseAPI {
|
||||
return DefaultApiFp(this.configuration).checkCache(checkCacheRequest, options).then((request) => request(this.axios, this.basePath));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the latest logged call (only for local testing)
|
||||
* @param {*} [options] Override http request option.
|
||||
* @throws {RequiredError}
|
||||
* @memberof DefaultApi
|
||||
*/
|
||||
public localTestingOnlyGetLatestLoggedCall(options?: AxiosRequestConfig) {
|
||||
return DefaultApiFp(this.configuration).localTestingOnlyGetLatestLoggedCall(options).then((request) => request(this.axios, this.basePath));
|
||||
}
|
||||
|
||||
/**
|
||||
* Report an API call
|
||||
* @param {ReportRequest} reportRequest
|
||||
|
||||
Reference in New Issue
Block a user