Record model and cost when reporting logs (#136)

* Rename prompt and completion tokens to input and output tokens

* Add getUsage function

* Record model and cost when reporting log

* Remove unused imports

* Move UsageGraph to its own component

* Standardize model response fields

* Fix types
This commit is contained in:
arcticfly
2023-08-11 13:56:47 -07:00
committed by GitHub
parent f270579283
commit 8d1ee62ff1
24 changed files with 295 additions and 199 deletions

View File

@@ -24,9 +24,9 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
sql<Date>`date_trunc('day', "LoggedCallModelResponse"."startTime")`.as("period"),
sql<Date>`date_trunc('day', "LoggedCallModelResponse"."requestedAt")`.as("period"),
sql<number>`count("LoggedCall"."id")::int`.as("numQueries"),
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
])
.groupBy("period")
.orderBy("period")
@@ -57,7 +57,7 @@ export const dashboardRouter = createTRPCRouter({
backfilledPeriods.unshift({
period: dayjs(dayToMatch).toDate(),
numQueries: 0,
totalCost: 0,
cost: 0,
});
}
dayToMatch = dayToMatch.subtract(1, "day");
@@ -72,7 +72,7 @@ export const dashboardRouter = createTRPCRouter({
)
.where("projectId", "=", input.projectId)
.select(({ fn }) => [
fn.sum(fn.coalesce("LoggedCallModelResponse.totalCost", sql<number>`0`)).as("totalCost"),
fn.sum(fn.coalesce("LoggedCallModelResponse.cost", sql<number>`0`)).as("cost"),
fn.count("LoggedCall.id").as("numQueries"),
])
.executeTakeFirst();
@@ -85,8 +85,8 @@ export const dashboardRouter = createTRPCRouter({
"LoggedCall.id",
"LoggedCallModelResponse.originalLoggedCallId",
)
.select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "respStatus as code"])
.where("respStatus", ">", 200)
.select(({ fn }) => [fn.count("LoggedCall.id").as("count"), "statusCode as code"])
.where("statusCode", ">", 200)
.groupBy("code")
.orderBy("count", "desc")
.execute();
@@ -108,7 +108,7 @@ export const dashboardRouter = createTRPCRouter({
// https://discord.com/channels/966627436387266600/1122258443886153758/1122258443886153758
loggedCalls: publicProcedure.input(z.object({})).query(async ({ input }) => {
const loggedCalls = await prisma.loggedCall.findMany({
orderBy: { startTime: "desc" },
orderBy: { requestedAt: "desc" },
include: { tags: true, modelResponse: true },
take: 20,
});

View File

@@ -227,7 +227,7 @@ export const experimentsRouter = createTRPCRouter({
...modelResponseData,
id: newModelResponseId,
scenarioVariantCellId: newCellId,
output: (modelResponse.output as Prisma.InputJsonValue) ?? undefined,
respPayload: (modelResponse.respPayload as Prisma.InputJsonValue) ?? undefined,
});
for (const evaluation of outputEvaluations) {
outputEvaluationsToCreate.push({

View File

@@ -7,6 +7,11 @@ import { TRPCError } from "@trpc/server";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db";
import { hashRequest } from "~/server/utils/hashObject";
import modelProvider from "~/modelProviders/openai-ChatCompletion";
import {
type ChatCompletion,
type CompletionCreateParams,
} from "openai/resources/chat/completions";
const reqValidator = z.object({
model: z.string(),
@@ -16,11 +21,6 @@ const reqValidator = z.object({
const respValidator = z.object({
id: z.string(),
model: z.string(),
usage: z.object({
total_tokens: z.number(),
prompt_tokens: z.number(),
completion_tokens: z.number(),
}),
choices: z.array(
z.object({
finish_reason: z.string(),
@@ -76,7 +76,7 @@ export const externalApiRouter = createTRPCRouter({
originalLoggedCall: true,
},
orderBy: {
startTime: "desc",
requestedAt: "desc",
},
});
@@ -85,7 +85,7 @@ export const externalApiRouter = createTRPCRouter({
await prisma.loggedCall.create({
data: {
projectId: key.projectId,
startTime: new Date(input.startTime),
requestedAt: new Date(input.startTime),
cacheHit: true,
modelResponseId: existingResponse.id,
},
@@ -140,14 +140,20 @@ export const externalApiRouter = createTRPCRouter({
const newLoggedCallId = uuidv4();
const newModelResponseId = uuidv4();
const usage = respPayload.success ? respPayload.data.usage : undefined;
let usage;
if (reqPayload.success && respPayload.success) {
usage = modelProvider.getUsage(
input.reqPayload as CompletionCreateParams,
input.respPayload as ChatCompletion,
);
}
await prisma.$transaction([
prisma.loggedCall.create({
data: {
id: newLoggedCallId,
projectId: key.projectId,
startTime: new Date(input.startTime),
requestedAt: new Date(input.startTime),
cacheHit: false,
},
}),
@@ -155,20 +161,17 @@ export const externalApiRouter = createTRPCRouter({
data: {
id: newModelResponseId,
originalLoggedCallId: newLoggedCallId,
startTime: new Date(input.startTime),
endTime: new Date(input.endTime),
requestedAt: new Date(input.startTime),
receivedAt: new Date(input.endTime),
reqPayload: input.reqPayload as Prisma.InputJsonValue,
respPayload: input.respPayload as Prisma.InputJsonValue,
respStatus: input.respStatus,
error: input.error,
statusCode: input.respStatus,
errorMessage: input.error,
durationMs: input.endTime - input.startTime,
...(respPayload.success
? {
cacheKey: requestHash,
inputTokens: usage ? usage.prompt_tokens : undefined,
outputTokens: usage ? usage.completion_tokens : undefined,
}
: null),
cacheKey: respPayload.success ? requestHash : null,
inputTokens: usage?.inputTokens,
outputTokens: usage?.outputTokens,
cost: usage?.cost,
},
}),
// Avoid foreign key constraint error by updating the logged call after the model response is created
@@ -182,24 +185,22 @@ export const externalApiRouter = createTRPCRouter({
}),
]);
if (input.tags) {
const tagsToCreate = Object.entries(input.tags).map(([name, value]) => ({
loggedCallId: newLoggedCallId,
// sanitize tags
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
value,
}));
const tagsToCreate = Object.entries(input.tags ?? {}).map(([name, value]) => ({
loggedCallId: newLoggedCallId,
// sanitize tags
name: name.replaceAll(/[^a-zA-Z0-9_]/g, "_"),
value,
}));
if (reqPayload.success) {
tagsToCreate.push({
loggedCallId: newLoggedCallId,
name: "$model",
value: reqPayload.data.model,
});
}
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
if (reqPayload.success) {
tagsToCreate.push({
loggedCallId: newLoggedCallId,
name: "$model",
value: reqPayload.data.model,
});
}
await prisma.loggedCallTag.createMany({
data: tagsToCreate,
});
}),
});

View File

@@ -55,7 +55,7 @@ export const promptVariantsRouter = createTRPCRouter({
where: {
modelResponse: {
outdated: false,
output: { not: Prisma.AnyNull },
respPayload: { not: Prisma.AnyNull },
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
@@ -100,7 +100,7 @@ export const promptVariantsRouter = createTRPCRouter({
modelResponses: {
some: {
outdated: false,
output: {
respPayload: {
not: Prisma.AnyNull,
},
},
@@ -111,7 +111,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallTokens = await prisma.modelResponse.aggregate({
where: {
outdated: false,
output: {
respPayload: {
not: Prisma.AnyNull,
},
scenarioVariantCell: {
@@ -123,13 +123,13 @@ export const promptVariantsRouter = createTRPCRouter({
},
_sum: {
cost: true,
promptTokens: true,
completionTokens: true,
inputTokens: true,
outputTokens: true,
},
});
const promptTokens = overallTokens._sum?.promptTokens ?? 0;
const completionTokens = overallTokens._sum?.completionTokens ?? 0;
const inputTokens = overallTokens._sum?.inputTokens ?? 0;
const outputTokens = overallTokens._sum?.outputTokens ?? 0;
const awaitingEvals = !!evalResults.find(
(result) => result.totalCount < scenarioCount * evals.length,
@@ -137,8 +137,8 @@ export const promptVariantsRouter = createTRPCRouter({
return {
evalResults,
promptTokens,
completionTokens,
inputTokens,
outputTokens,
overallCost: overallTokens._sum?.cost ?? 0,
scenarioCount,
outputCount,