Add checkCache and report routes
This commit is contained in:
22
app/src/pages/api/[...trpc].ts
Normal file
22
app/src/pages/api/[...trpc].ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { type NextApiRequest, type NextApiResponse } from "next";
|
||||
import cors from "nextjs-cors";
|
||||
import { createOpenApiNextHandler } from "trpc-openapi";
|
||||
import { createProcedureCache } from "trpc-openapi/dist/adapters/node-http/procedures";
|
||||
import { appRouter } from "~/server/api/root.router";
|
||||
import { createTRPCContext } from "~/server/api/trpc";
|
||||
|
||||
const openApiHandler = createOpenApiNextHandler({
|
||||
router: appRouter,
|
||||
createContext: createTRPCContext,
|
||||
});
|
||||
|
||||
const cache = createProcedureCache(appRouter);
|
||||
|
||||
const handler = async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
// Setup CORS
|
||||
await cors(req, res);
|
||||
|
||||
return openApiHandler(req, res);
|
||||
};
|
||||
|
||||
export default handler;
|
||||
16
app/src/pages/api/openapi.json.ts
Normal file
16
app/src/pages/api/openapi.json.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { type NextApiRequest, type NextApiResponse } from "next";
|
||||
import { generateOpenApiDocument } from "trpc-openapi";
|
||||
import { appRouter } from "~/server/api/root.router";
|
||||
|
||||
export const openApiDocument = generateOpenApiDocument(appRouter, {
|
||||
title: "OpenPipe API",
|
||||
description: "The public API for reporting API calls to OpenPipe",
|
||||
version: "0.1.0",
|
||||
baseUrl: "https://app.openpipe.ai/api",
|
||||
});
|
||||
// Respond with our OpenAPI schema
|
||||
const hander = (req: NextApiRequest, res: NextApiResponse) => {
|
||||
res.status(200).send(openApiDocument);
|
||||
};
|
||||
|
||||
export default hander;
|
||||
197
app/src/server/api/routers/externalApi.router.ts
Normal file
197
app/src/server/api/routers/externalApi.router.ts
Normal file
@@ -0,0 +1,197 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { z } from "zod";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { hashRequest } from "~/server/utils/hashObject";
|
||||
|
||||
const reqValidator = z.object({
|
||||
model: z.string(),
|
||||
messages: z.array(z.any()),
|
||||
});
|
||||
|
||||
const respValidator = z.object({
|
||||
id: z.string(),
|
||||
model: z.string(),
|
||||
usage: z.object({
|
||||
total_tokens: z.number(),
|
||||
prompt_tokens: z.number(),
|
||||
completion_tokens: z.number(),
|
||||
}),
|
||||
choices: z.array(
|
||||
z.object({
|
||||
finish_reason: z.string(),
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
export const externalApiRouter = createTRPCRouter({
|
||||
checkCache: publicProcedure
|
||||
.meta({
|
||||
openapi: {
|
||||
method: "POST",
|
||||
path: "/v1/check-cache",
|
||||
description: "Check if a prompt is cached",
|
||||
},
|
||||
})
|
||||
.input(
|
||||
z.object({
|
||||
startTime: z.number().describe("Unix timestamp in milliseconds"),
|
||||
reqPayload: z.unknown().describe("JSON-encoded request payload"),
|
||||
tags: z
|
||||
.record(z.string())
|
||||
.optional()
|
||||
.describe(
|
||||
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
|
||||
),
|
||||
}),
|
||||
)
|
||||
.output(
|
||||
z.object({
|
||||
respPayload: z.unknown().optional().describe("JSON-encoded response payload"),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const apiKey = ctx.apiKey;
|
||||
if (!apiKey) {
|
||||
throw new Error("Missing API key");
|
||||
}
|
||||
const key = await prisma.apiKey.findUnique({
|
||||
where: { apiKey },
|
||||
});
|
||||
if (!key) {
|
||||
throw new Error("Invalid API key");
|
||||
}
|
||||
const reqPayload = await reqValidator.spa(input.reqPayload);
|
||||
const cacheKey = hashRequest(key.organizationId, reqPayload as JsonValue);
|
||||
|
||||
const existingResponse = await prisma.loggedCallModelResponse.findFirst({
|
||||
where: {
|
||||
cacheKey,
|
||||
},
|
||||
include: {
|
||||
originalLoggedCall: true,
|
||||
},
|
||||
orderBy: {
|
||||
startTime: "desc",
|
||||
}
|
||||
});
|
||||
|
||||
if (!existingResponse) return { respPayload: null };
|
||||
|
||||
await prisma.loggedCall.create({
|
||||
data: {
|
||||
organizationId: key.organizationId,
|
||||
startTime: new Date(input.startTime),
|
||||
cacheHit: false,
|
||||
modelResponseId: existingResponse.id,
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
respPayload: existingResponse.respPayload,
|
||||
};
|
||||
}),
|
||||
|
||||
report: publicProcedure
|
||||
.meta({
|
||||
openapi: {
|
||||
method: "POST",
|
||||
path: "/v1/report",
|
||||
description: "Report an API call",
|
||||
},
|
||||
})
|
||||
.input(
|
||||
z.object({
|
||||
startTime: z.number().describe("Unix timestamp in milliseconds"),
|
||||
endTime: z.number().describe("Unix timestamp in milliseconds"),
|
||||
reqPayload: z.unknown().describe("JSON-encoded request payload"),
|
||||
respPayload: z.unknown().optional().describe("JSON-encoded response payload"),
|
||||
respStatus: z.number().optional().describe("HTTP status code of response"),
|
||||
error: z.string().optional().describe("User-friendly error message"),
|
||||
tags: z
|
||||
.record(z.string())
|
||||
.optional()
|
||||
.describe(
|
||||
'Extra tags to attach to the call for filtering. Eg { "userId": "123", "promptId": "populate-title" }',
|
||||
),
|
||||
}),
|
||||
)
|
||||
.output(z.void())
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const apiKey = ctx.apiKey;
|
||||
if (!apiKey) {
|
||||
throw new Error("Missing API key");
|
||||
}
|
||||
const key = await prisma.apiKey.findUnique({
|
||||
where: { apiKey },
|
||||
});
|
||||
if (!key) {
|
||||
throw new Error("Invalid API key");
|
||||
}
|
||||
const reqPayload = await reqValidator.spa(input.reqPayload);
|
||||
const respPayload = await respValidator.spa(input.respPayload);
|
||||
|
||||
const requestHash = hashRequest(key.organizationId, reqPayload as JsonValue);
|
||||
|
||||
const newLoggedCallId = uuidv4();
|
||||
const newModelResponseId = uuidv4();
|
||||
|
||||
const usage = respPayload.success ? respPayload.data.usage : undefined;
|
||||
|
||||
await prisma.$transaction([
|
||||
prisma.loggedCall.create({
|
||||
data: {
|
||||
id: newLoggedCallId,
|
||||
organizationId: key.organizationId,
|
||||
startTime: new Date(input.startTime),
|
||||
cacheHit: false,
|
||||
modelResponseId: newModelResponseId,
|
||||
},
|
||||
}),
|
||||
prisma.loggedCallModelResponse.create({
|
||||
data: {
|
||||
id: newModelResponseId,
|
||||
originalLoggedCallId: newLoggedCallId,
|
||||
startTime: new Date(input.startTime),
|
||||
endTime: new Date(input.endTime),
|
||||
reqPayload: input.reqPayload as Prisma.InputJsonValue,
|
||||
respPayload: input.respPayload as Prisma.InputJsonValue,
|
||||
respStatus: input.respStatus,
|
||||
error: input.error,
|
||||
durationMs: input.endTime - input.startTime,
|
||||
...(respPayload.success
|
||||
? {
|
||||
cacheKey: requestHash,
|
||||
inputTokens: usage ? usage.prompt_tokens : undefined,
|
||||
outputTokens: usage ? usage.completion_tokens : undefined,
|
||||
model: respPayload.data.model,
|
||||
}
|
||||
: null),
|
||||
},
|
||||
}),
|
||||
]);
|
||||
|
||||
if (input.tags) {
|
||||
const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
|
||||
loggedCallId: newLoggedCallId,
|
||||
// sanitize tags
|
||||
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
|
||||
value,
|
||||
}));
|
||||
|
||||
if (reqPayload.success) {
|
||||
tagsToCreate.push({
|
||||
loggedCallId: newLoggedCallId,
|
||||
name: "$model",
|
||||
value: reqPayload.data.model,
|
||||
});
|
||||
}
|
||||
await prisma.loggedCallTag.createMany({
|
||||
data: tagsToCreate,
|
||||
});
|
||||
}
|
||||
}),
|
||||
});
|
||||
@@ -11,6 +11,7 @@ import { initTRPC, TRPCError } from "@trpc/server";
|
||||
import { type CreateNextContextOptions } from "@trpc/server/adapters/next";
|
||||
import { type Session } from "next-auth";
|
||||
import superjson from "superjson";
|
||||
import { type OpenApiMeta } from "trpc-openapi";
|
||||
import { ZodError } from "zod";
|
||||
import { getServerAuthSession } from "~/server/auth";
|
||||
import { prisma } from "~/server/db";
|
||||
@@ -26,6 +27,7 @@ import { capturePath } from "~/utils/analytics/serverAnalytics";
|
||||
|
||||
type CreateContextOptions = {
|
||||
session: Session | null;
|
||||
apiKey: string | null;
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||
@@ -44,6 +46,7 @@ const noOp = () => {};
|
||||
export const createInnerTRPCContext = (opts: CreateContextOptions) => {
|
||||
return {
|
||||
session: opts.session,
|
||||
apiKey: opts.apiKey,
|
||||
prisma,
|
||||
markAccessControlRun: noOp,
|
||||
};
|
||||
@@ -61,8 +64,13 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
||||
// Get the session from the server using the getServerSession wrapper function
|
||||
const session = await getServerAuthSession({ req, res });
|
||||
|
||||
const apiKey = req.headers["x-openpipe-api-key"] as string | null;
|
||||
|
||||
console.log('api key is', apiKey)
|
||||
|
||||
return createInnerTRPCContext({
|
||||
session,
|
||||
apiKey,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -76,18 +84,21 @@ export const createTRPCContext = async (opts: CreateNextContextOptions) => {
|
||||
|
||||
export type TRPCContext = Awaited<ReturnType<typeof createTRPCContext>>;
|
||||
|
||||
const t = initTRPC.context<typeof createTRPCContext>().create({
|
||||
transformer: superjson,
|
||||
errorFormatter({ shape, error }) {
|
||||
return {
|
||||
...shape,
|
||||
data: {
|
||||
...shape.data,
|
||||
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
const t = initTRPC
|
||||
.context<typeof createTRPCContext>()
|
||||
.meta<OpenApiMeta>()
|
||||
.create({
|
||||
transformer: superjson,
|
||||
errorFormatter({ shape, error }) {
|
||||
return {
|
||||
...shape,
|
||||
data: {
|
||||
...shape.data,
|
||||
zodError: error.cause instanceof ZodError ? error.cause.flatten() : null,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* 3. ROUTER & PROCEDURE (THE IMPORTANT BIT)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { type Prisma } from "@prisma/client";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import { JsonValue, type JsonObject } from "type-fest";
|
||||
import modelProviders from "~/modelProviders/modelProviders";
|
||||
import { prisma } from "~/server/db";
|
||||
import { wsConnection } from "~/utils/wsConnection";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import hashPrompt from "../utils/hashPrompt";
|
||||
import hashObject from "../utils/hashObject";
|
||||
import defineTask from "./defineTask";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
|
||||
@@ -99,7 +99,7 @@ export const queryModel = defineTask<QueryModelJob>("queryModel", async (task) =
|
||||
}
|
||||
: null;
|
||||
|
||||
const inputHash = hashPrompt(prompt);
|
||||
const inputHash = hashObject(prompt as JsonValue);
|
||||
|
||||
let modelResponse = await prisma.modelResponse.create({
|
||||
data: {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { type JsonObject } from "type-fest";
|
||||
import hashPrompt from "./hashPrompt";
|
||||
import hashObject from "./hashObject";
|
||||
import { omit } from "lodash-es";
|
||||
import { queueQueryModel } from "../tasks/queryModel.task";
|
||||
import parsePromptConstructor from "~/promptConstructor/parse";
|
||||
@@ -57,7 +57,7 @@ export const generateNewCell = async (
|
||||
return;
|
||||
}
|
||||
|
||||
const inputHash = hashPrompt(parsedConstructFn);
|
||||
const inputHash = hashObject(parsedConstructFn);
|
||||
|
||||
cell = await prisma.scenarioVariantCell.create({
|
||||
data: {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import crypto from "crypto";
|
||||
import { type JsonValue } from "type-fest";
|
||||
import { ParsedPromptConstructor } from "~/promptConstructor/parse";
|
||||
|
||||
function sortKeys(obj: JsonValue): JsonValue {
|
||||
if (typeof obj !== "object" || obj === null) {
|
||||
@@ -25,9 +24,17 @@ function sortKeys(obj: JsonValue): JsonValue {
|
||||
return sortedObj;
|
||||
}
|
||||
|
||||
export default function hashPrompt(prompt: ParsedPromptConstructor<any>): string {
|
||||
export function hashRequest(organizationId: string, reqPayload: JsonValue): string {
|
||||
const obj = {
|
||||
organizationId,
|
||||
reqPayload,
|
||||
};
|
||||
return hashObject(obj);
|
||||
}
|
||||
|
||||
export default function hashObject(obj: JsonValue): string {
|
||||
// Sort object keys recursively
|
||||
const sortedObj = sortKeys(prompt as unknown as JsonValue);
|
||||
const sortedObj = sortKeys(obj);
|
||||
|
||||
// Convert to JSON and hash it
|
||||
const str = JSON.stringify(sortedObj);
|
||||
Reference in New Issue
Block a user