Add checkCache and report routes

This commit is contained in:
David Corbitt
2023-08-05 20:37:16 -07:00
parent 9e859c199e
commit 7f8b574c9f
12 changed files with 489 additions and 29 deletions

View File

@@ -0,0 +1,22 @@
import { type NextApiRequest, type NextApiResponse } from "next";
import cors from "nextjs-cors";
import { createOpenApiNextHandler } from "trpc-openapi";
import { createProcedureCache } from "trpc-openapi/dist/adapters/node-http/procedures";
import { appRouter } from "~/server/api/root.router";
import { createTRPCContext } from "~/server/api/trpc";
const openApiHandler = createOpenApiNextHandler({
router: appRouter,
createContext: createTRPCContext,
});
const cache = createProcedureCache(appRouter);
const handler = async (req: NextApiRequest, res: NextApiResponse) => {
// Setup CORS
await cors(req, res);
return openApiHandler(req, res);
};
export default handler;

View File

@@ -0,0 +1,16 @@
import { type NextApiRequest, type NextApiResponse } from "next";
import { generateOpenApiDocument } from "trpc-openapi";
import { appRouter } from "~/server/api/root.router";
export const openApiDocument = generateOpenApiDocument(appRouter, {
title: "OpenPipe API",
description: "The public API for reporting API calls to OpenPipe",
version: "0.1.0",
baseUrl: "https://app.openpipe.ai/api",
});
// Respond with our OpenAPI schema
const hander = (req: NextApiRequest, res: NextApiResponse) => {
res.status(200).send(openApiDocument);
};
export default hander;

View File

@@ -0,0 +1,197 @@
import { type Prisma } from "@prisma/client";
import { type JsonValue } from "type-fest";
import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject";
const reqValidator = z.object({
model: z.string(),
messages: z.array(z.any()),
});
const respValidator = z.object({
id: z.string(),
model: z.string(),
usage: z.object({
total_tokens: z.number(),
prompt_tokens: z.number(),
completion_tokens: z.number(),
}),
choices: z.array(
z.object({
finish_reason: z.string(),
}),
),
});
export const externalApiRouter = createTRPCRouter({
checkCache: publicProcedure
.meta({
openapi: {
method: "POST",
path: "/v1/check-cache",
description: "Check if a prompt is cached",
},
})
.input(
z.object({
startTime: z.number().describe("Unix timestamp in milliseconds"),
reqPayload: z.unknown().describe("JSON-encoded request payload"),
tags: z
.record(z.string())
.optional()
.describe(
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
),
}),
)
.output(
z.object({
respPayload: z.unknown().optional().describe("JSON-encoded response payload"),
}),
)
.mutation(async ({ input, ctx }) => {
const apiKey = ctx.apiKey;
if (!apiKey) {
throw new Error("Missing API key");
}
const key = await prisma.apiKey.findUnique({
where: { apiKey },
});
if (!key) {
throw new Error("Invalid API key");
}
const reqPayload = await reqValidator.spa(input.reqPayload);
const cacheKey = hashRequest(key.organizationId, reqPayload as JsonValue);
const existingResponse = await prisma.loggedCallModelResponse.findFirst({
where: {
cacheKey,
},
include: {
originalLoggedCall: true,
},
orderBy: {
startTime: "desc",
}
});
if (!existingResponse) return { respPayload: null };
await prisma.loggedCall.create({
data: {
organizationId: key.organizationId,
startTime: new Date(input.startTime),
cacheHit: false,
modelResponseId: existingResponse.id,
}
})
return {
respPayload: existingResponse.respPayload,
};
}),
report: publicProcedure
.meta({
openapi: {
method: "POST",
path: "/v1/report",
description: "Report an API call",
},
})
.input(
z.object({
startTime: z.number().describe("Unix timestamp in milliseconds"),
endTime: z.number().describe("Unix timestamp in milliseconds"),
reqPayload: z.unknown().describe("JSON-encoded request payload"),
respPayload: z.unknown().optional().describe("JSON-encoded response payload"),
respStatus: z.number().optional().describe("HTTP status code of response"),
error: z.string().optional().describe("User-friendly error message"),
tags: z
.record(z.string())
.optional()
.describe(
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
),
}),
)
.output(z.void())
.mutation(async ({ input, ctx }) => {
const apiKey = ctx.apiKey;
if (!apiKey) {
throw new Error("Missing API key");
}
const key = await prisma.apiKey.findUnique({
where: { apiKey },
});
if (!key) {
throw new Error("Invalid API key");
}
const reqPayload = await reqValidator.spa(input.reqPayload);
const respPayload = await respValidator.spa(input.respPayload);
const requestHash = hashRequest(key.organizationId, reqPayload as JsonValue);
const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4();
const usage = respPayload.success ? respPayload.data.usage : undefined;
await prisma.$transaction([
prisma.loggedCall.create({
data: {
id: newLoggedCallId,
organizationId: key.organizationId,
startTime: new Date(input.startTime),
cacheHit: false,
modelResponseId: newModelResponseId,
},
}),
prisma.loggedCallModelResponse.create({
data: {
id: newModelResponseId,
originalLoggedCallId: newLoggedCallId,
startTime: new Date(input.startTime),
endTime: new Date(input.endTime),
reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue,
respStatus: input.respStatus,
error: input.error,
durationMs: input.endTime - input.startTime,
...(respPayload.success
? {
cacheKey: requestHash,
inputTokens: usage ? usage.prompt_tokens : undefined,
outputTokens: usage ? usage.completion_tokens : undefined,
model: respPayload.data.model,
}
: null),
},
}),
]);
if (input.tags) {
const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
loggedCallId: newLoggedCallId,
// sanitize tags
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
value,
}));
if (reqPayload.success) {
tagsToCreate.push({
loggedCallId: newLoggedCallId,
name: "$model",
value: reqPayload.data.model,
});
}
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
});
}
}),
});

View File

@@ -11,6 +11,7 @@ import { initTRPC, TRPCError } from "@trpc/server";
import { type CreateNextContextOptions } from "@trpc/server/adapters/next";
import { type Session } from "next-auth";
import superjson from "superjson";
import { type OpenApiMeta } from "trpc-openapi";
import { ZodError } from "zod";
import { getServerAuthSession } from "~/server/auth";
import { prisma } from "~/server/db";
@@ -26,6 +27,7 @@ import { capturePath } from "~/utils/analytics/serverAnalytics";
type CreateContextOptions = {
session: Session | null;
apiKey: string | null;
};
// eslint-disable-next-line @typescript-eslint/no-empty-function
@@ -44,6 +46,7 @@ const noOp = () => {};
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
return {
session: opts.session,
apiKey: opts.apiKey,
prisma,
markAccessControlRun: noOp,
};
@@ -61,8 +64,13 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
// Get the session from the server using the getServerSession wrapper function
const session = await getServerAuthSession({ req, res });
const apiKey = req.headers["x-openpipe-api-key"] as string | null;
console.log('api key is', apiKey)
return createInnerTRPCContext({
session,
apiKey,
});
};
@@ -76,18 +84,21 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
const t = initTRPC.context<typeof createTRPCContext>().create({
transformer: superjson,
errorFormatter({ shape, error }) {
return {
...shape,
data: {
...shape.data,
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
},
};
},
});
const t = initTRPC
.context<typeof createTRPCContext>()
.meta<OpenApiMeta>()
.create({
transformer: superjson,
errorFormatter({ shape, error }) {
return {
...shape,
data: {
...shape.data,
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
},
};
},
});
/**
* 3. ROUTER & PROCEDURE (THE IMPORTANT BIT)

View File

@@ -1,10 +1,10 @@
import { type Prisma } from "@prisma/client";
import { type JsonObject } from "type-fest";
import { JsonValue, type JsonObject } from "type-fest";
import modelProviders from "~/modelProviders/modelProviders";
import { prisma } from "~/server/db";
import { wsConnection } from "~/utils/wsConnection";
import { runEvalsForOutput } from "../utils/evaluations";
import hashPrompt from "../utils/hashPrompt";
import hashObject from "../utils/hashObject";
import defineTask from "./defineTask";
import parsePromptConstructor from "~/promptConstructor/parse";
@@ -99,7 +99,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
}
: null;
const inputHash = hashPrompt(prompt);
const inputHash = hashObject(prompt as JsonValue);
let modelResponse = await prisma.modelResponse.create({
data: {

View File

@@ -1,7 +1,7 @@
import { Prisma } from "@prisma/client";
import { prisma } from "../db";
import { type JsonObject } from "type-fest";
import hashPrompt from "./hashPrompt";
import hashObject from "./hashObject";
import { omit } from "lodash-es";
import { queueQueryModel } from "../tasks/queryModel.task";
import parsePromptConstructor from "~/promptConstructor/parse";
@@ -57,7 +57,7 @@ export const generateNewCell = async (
return;
}
const inputHash = hashPrompt(parsedConstructFn);
const inputHash = hashObject(parsedConstructFn);
cell = await prisma.scenarioVariantCell.create({
data: {

View File

@@ -1,6 +1,5 @@
import crypto from "crypto";
import { type JsonValue } from "type-fest";
import { ParsedPromptConstructor } from "~/promptConstructor/parse";
function sortKeys(obj: JsonValue): JsonValue {
if (typeof obj !== "object" || obj === null) {
@@ -25,9 +24,17 @@ function sortKeys(obj: JsonValue): JsonValue {
return sortedObj;
}
export default function hashPrompt(prompt: ParsedPromptConstructor<any>): string {
export function hashRequest(organizationId: string, reqPayload: JsonValue): string {
const obj = {
organizationId,
reqPayload,
};
return hashObject(obj);
}
export default function hashObject(obj: JsonValue): string {
// Sort object keys recursively
const sortedObj = sortKeys(prompt as unknown as JsonValue);
const sortedObj = sortKeys(obj);
// Convert to JSON and hash it
const str = JSON.stringify(sortedObj);