From 011b12abb9f18a2283578dbb68fc04876cf4a1a9 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Mon, 17 Jul 2023 16:52:26 -0700 Subject: [PATCH] cache output evals --- prisma/schema.prisma | 19 +- .../OutputsTable/EditEvaluations.tsx | 3 +- .../OutputsTable/OutputCell/OutputStats.tsx | 19 +- src/components/OutputsTable/VariantStats.tsx | 4 +- src/server/api/routers/evaluations.router.ts | 27 +- .../api/routers/promptVariants.router.ts | 42 +++- .../routers/scenarioVariantCells.router.ts | 12 +- src/server/api/routers/scenarios.router.ts | 4 +- src/server/tasks/queryLLM.task.ts | 6 +- src/server/utils/evaluations.ts | 233 +++++++++++------- .../{evaluateOutput.ts => runOneEval.ts} | 19 +- 11 files changed, 244 insertions(+), 144 deletions(-) rename src/server/utils/{evaluateOutput.ts => runOneEval.ts} (74%) diff --git a/prisma/schema.prisma b/prisma/schema.prisma index bf05a85..119a78d 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -41,7 +41,6 @@ model PromptVariant { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt scenarioVariantCells ScenarioVariantCell[] - EvaluationResult EvaluationResult[] @@index([uiId]) } @@ -124,6 +123,7 @@ model ModelOutput { scenarioVariantCellId String @db.Uuid scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) + outputEvaluation OutputEvaluation[] @@unique([scenarioVariantCellId]) @@index([inputHash]) @@ -146,25 +146,26 @@ model Evaluation { createdAt DateTime @default(now()) updatedAt DateTime @updatedAt - EvaluationResult EvaluationResult[] + OutputEvaluation OutputEvaluation[] } -model EvaluationResult { +model OutputEvaluation { id String @id @default(uuid()) @db.Uuid - passCount Int - failCount Int + // Number between 0 (fail) and 1 (pass) + result Float + details String? + + modelOutputId String @db.Uuid + modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade) evaluationId String @db.Uuid evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) - promptVariantId String @db.Uuid - promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade) - createdAt DateTime @default(now()) updatedAt DateTime @updatedAt - @@unique([evaluationId, promptVariantId]) + @@unique([modelOutputId, evaluationId]) } // Necessary for Next auth diff --git a/src/components/OutputsTable/EditEvaluations.tsx b/src/components/OutputsTable/EditEvaluations.tsx index 7d1a44c..a672064 100644 --- a/src/components/OutputsTable/EditEvaluations.tsx +++ b/src/components/OutputsTable/EditEvaluations.tsx @@ -40,7 +40,7 @@ export function EvaluationEditor(props: { setValues((values) => ({ ...values, name: e.target.value }))} + onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))} /> @@ -125,6 +125,7 @@ export default function EditEvaluations() { } await utils.evaluations.list.invalidate(); await utils.promptVariants.stats.invalidate(); + await utils.scenarioVariantCells.get.invalidate(); }, []); const onCancel = useCallback(() => { diff --git a/src/components/OutputsTable/OutputCell/OutputStats.tsx b/src/components/OutputsTable/OutputCell/OutputStats.tsx index 24a92b0..5256168 100644 --- a/src/components/OutputsTable/OutputCell/OutputStats.tsx +++ b/src/components/OutputsTable/OutputCell/OutputStats.tsx @@ -1,10 +1,7 @@ -import { type ModelOutput } from "@prisma/client"; import { type SupportedModel } from "~/server/types"; import { type Scenario } from "../types"; -import { useExperiment } from "~/utils/hooks"; -import { api } from "~/utils/api"; +import { type RouterOutputs } from "~/utils/api"; import { calculateTokenCost } from "~/utils/calculateTokenCost"; -import { evaluateOutput } from "~/server/utils/evaluateOutput"; import { HStack, Icon, Text } from "@chakra-ui/react"; import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; import { CostTooltip } from "~/components/tooltip/CostTooltip"; @@ -15,16 +12,14 @@ const SHOW_TIME = true; export const OutputStats = ({ model, modelOutput, - scenario, }: { model: SupportedModel | string | null; - modelOutput: ModelOutput; + modelOutput: NonNullable< + NonNullable["modelOutput"] + >; scenario: Scenario; }) => { const timeToComplete = modelOutput.timeToComplete; - const experiment = useExperiment(); - const evals = - api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? []; const promptTokens = modelOutput.promptTokens; const completionTokens = modelOutput.completionTokens; @@ -38,11 +33,11 @@ export const OutputStats = ({ return ( - {evals.map((evaluation) => { - const passed = evaluateOutput(modelOutput, scenario, evaluation); + {modelOutput.outputEvaluation.map((evaluation) => { + const passed = evaluation.result > 0.5; return ( - {evaluation.label} + {evaluation.evaluation.label} {data.evalResults.map((result) => { - const passedFrac = result.passCount / (result.passCount + result.failCount); + const passedFrac = result.passCount / result.totalCount; return ( - {result.evaluation.label} + {result.label} {(passedFrac * 100).toFixed(1)}% diff --git a/src/server/api/routers/evaluations.router.ts b/src/server/api/routers/evaluations.router.ts index 99c1a8c..7ee0d12 100644 --- a/src/server/api/routers/evaluations.router.ts +++ b/src/server/api/routers/evaluations.router.ts @@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client"; import { z } from "zod"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; -import { reevaluateEvaluation } from "~/server/utils/evaluations"; +import { runAllEvals } from "~/server/utils/evaluations"; export const evaluationsRouter = createTRPCRouter({ list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { @@ -24,7 +24,7 @@ export const evaluationsRouter = createTRPCRouter({ }), ) .mutation(async ({ input }) => { - const evaluation = await prisma.evaluation.create({ + await prisma.evaluation.create({ data: { experimentId: input.experimentId, label: input.label, @@ -32,7 +32,10 @@ export const evaluationsRouter = createTRPCRouter({ evalType: input.evalType, }, }); - await reevaluateEvaluation(evaluation); + + // TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need + // to kick off a background job or something instead + await runAllEvals(input.experimentId); }), update: publicProcedure @@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({ z.object({ id: z.string(), updates: z.object({ - name: z.string().optional(), + label: z.string().optional(), value: z.string().optional(), evalType: z.nativeEnum(EvalType).optional(), }), }), ) .mutation(async ({ input }) => { - await prisma.evaluation.update({ + const evaluation = await prisma.evaluation.update({ where: { id: input.id }, data: { - label: input.updates.name, + label: input.updates.label, value: input.updates.value, evalType: input.updates.evalType, }, }); - await reevaluateEvaluation( - await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }), - ); + + await prisma.outputEvaluation.deleteMany({ + where: { + evaluationId: evaluation.id, + }, + }); + // Re-run all evals. Other eval results will already be cached, so this + // should only re-run the updated one. + await runAllEvals(evaluation.experimentId); }), delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { diff --git a/src/server/api/routers/promptVariants.router.ts b/src/server/api/routers/promptVariants.router.ts index ece4815..20f50d0 100644 --- a/src/server/api/routers/promptVariants.router.ts +++ b/src/server/api/routers/promptVariants.router.ts @@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({ throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); } - const evalResults = await prisma.evaluationResult.findMany({ - where: { - promptVariantId: input.variantId, + const outputEvals = await prisma.outputEvaluation.groupBy({ + by: ["evaluationId"], + _sum: { + result: true, }, - include: { evaluation: true }, + _count: { + id: true, + }, + where: { + modelOutput: { + scenarioVariantCell: { + promptVariant: { + id: input.variantId, + visible: true, + }, + testScenario: { + visible: true, + }, + }, + }, + }, + }); + + const evals = await prisma.evaluation.findMany({ + where: { + experimentId: variant.experimentId, + }, + }); + + const evalResults = evals.map((evalItem) => { + const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id); + return { + id: evalItem.id, + label: evalItem.label, + passCount: evalResult?._sum?.result ?? 0, + totalCount: evalResult?._count?.id ?? 1, + }; }); const scenarioCount = await prisma.testScenario.count({ @@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({ promptVariantId: input.variantId, testScenario: { visible: true }, modelOutput: { - isNot: null, + is: {}, }, }, }); diff --git a/src/server/api/routers/scenarioVariantCells.router.ts b/src/server/api/routers/scenarioVariantCells.router.ts index 09e1172..b07657e 100644 --- a/src/server/api/routers/scenarioVariantCells.router.ts +++ b/src/server/api/routers/scenarioVariantCells.router.ts @@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({ }, }, include: { - modelOutput: true, + modelOutput: { + include: { + outputEvaluation: { + include: { + evaluation: { + select: { label: true }, + }, + }, + }, + }, + }, }, }); }), diff --git a/src/server/api/routers/scenarios.router.ts b/src/server/api/routers/scenarios.router.ts index 5075ae7..0ddfb0b 100644 --- a/src/server/api/routers/scenarios.router.ts +++ b/src/server/api/routers/scenarios.router.ts @@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { prisma } from "~/server/db"; import { autogenerateScenarioValues } from "../autogen"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; -import { reevaluateAll } from "~/server/utils/evaluations"; +import { runAllEvals } from "~/server/utils/evaluations"; import { generateNewCell } from "~/server/utils/generateNewCell"; export const scenariosRouter = createTRPCRouter({ @@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({ }); // Reevaluate all evaluations now that this scenario is hidden - await reevaluateAll(hiddenScenario.experimentId); + await runAllEvals(hiddenScenario.experimentId); return hiddenScenario; }), diff --git a/src/server/tasks/queryLLM.task.ts b/src/server/tasks/queryLLM.task.ts index 541d17b..1eb94c7 100644 --- a/src/server/tasks/queryLLM.task.ts +++ b/src/server/tasks/queryLLM.task.ts @@ -6,7 +6,7 @@ import { type JSONSerializable } from "../types"; import { sleep } from "../utils/sleep"; import { shouldStream } from "../utils/shouldStream"; import { generateChannel } from "~/utils/generateChannel"; -import { reevaluateVariant } from "../utils/evaluations"; +import { runEvalsForOutput } from "../utils/evaluations"; import { constructPrompt } from "../utils/constructPrompt"; import { type CompletionCreateParams } from "openai/resources/chat"; import { type Prisma } from "@prisma/client"; @@ -148,5 +148,7 @@ export const queryLLM = defineTask("queryLLM", async (task) => { }, }); - await reevaluateVariant(cell.promptVariantId); + if (modelOutput) { + await runEvalsForOutput(variant.experimentId, scenario, modelOutput); + } }); diff --git a/src/server/utils/evaluations.ts b/src/server/utils/evaluations.ts index 3b24334..8740b88 100644 --- a/src/server/utils/evaluations.ts +++ b/src/server/utils/evaluations.ts @@ -1,105 +1,154 @@ import { type ModelOutput, type Evaluation } from "@prisma/client"; import { prisma } from "../db"; -import { evaluateOutput } from "./evaluateOutput"; +import { runOneEval } from "./runOneEval"; +import { type Scenario } from "~/components/OutputsTable/types"; -export const reevaluateVariant = async (variantId: string) => { - const variant = await prisma.promptVariant.findUnique({ - where: { id: variantId }, - }); - if (!variant) return; - - const evaluations = await prisma.evaluation.findMany({ - where: { experimentId: variant.experimentId }, - }); - - const cells = await prisma.scenarioVariantCell.findMany({ +const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => { + const result = runOneEval(evaluation, scenario, modelOutput); + return await prisma.outputEvaluation.upsert({ where: { - promptVariantId: variantId, - retrievalStatus: "COMPLETE", - testScenario: { visible: true }, - modelOutput: { isNot: null }, + modelOutputId_evaluationId: { + modelOutputId: modelOutput.id, + evaluationId: evaluation.id, + }, + }, + create: { + modelOutputId: modelOutput.id, + evaluationId: evaluation.id, + result, + }, + update: { + result, }, - include: { testScenario: true, modelOutput: true }, }); - - await Promise.all( - evaluations.map(async (evaluation) => { - const passCount = cells.filter((cell) => - evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation), - ).length; - const failCount = cells.length - passCount; - - await prisma.evaluationResult.upsert({ - where: { - evaluationId_promptVariantId: { - evaluationId: evaluation.id, - promptVariantId: variantId, - }, - }, - create: { - evaluationId: evaluation.id, - promptVariantId: variantId, - passCount, - failCount, - }, - update: { - passCount, - failCount, - }, - }); - }), - ); }; -export const reevaluateEvaluation = async (evaluation: Evaluation) => { - const variants = await prisma.promptVariant.findMany({ - where: { experimentId: evaluation.experimentId, visible: true }, - }); - - const cells = await prisma.scenarioVariantCell.findMany({ - where: { - promptVariantId: { in: variants.map((v) => v.id) }, - testScenario: { visible: true }, - statusCode: { notIn: [429] }, - modelOutput: { isNot: null }, - }, - include: { testScenario: true, modelOutput: true }, - }); - - await Promise.all( - variants.map(async (variant) => { - const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id); - const passCount = variantCells.filter((cell) => - evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation), - ).length; - const failCount = variantCells.length - passCount; - - await prisma.evaluationResult.upsert({ - where: { - evaluationId_promptVariantId: { - evaluationId: evaluation.id, - promptVariantId: variant.id, - }, - }, - create: { - evaluationId: evaluation.id, - promptVariantId: variant.id, - passCount, - failCount, - }, - update: { - passCount, - failCount, - }, - }); - }), - ); -}; - -export const reevaluateAll = async (experimentId: string) => { +export const runEvalsForOutput = async ( + experimentId: string, + scenario: Scenario, + modelOutput: ModelOutput, +) => { const evaluations = await prisma.evaluation.findMany({ where: { experimentId }, }); - await Promise.all(evaluations.map(reevaluateEvaluation)); + await Promise.all( + evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)), + ); + + // const cells = await prisma.scenarioVariantCell.findMany({ + // where: { + // promptVariantId: variantId, + // retrievalStatus: "COMPLETE", + // testScenario: { visible: true }, + // }, + // include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } }, + // }); + + // await Promise.all( + // evaluations.map(async (evaluation) => { + // const passCount = cells.filter((cell) => + // runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation), + // ).length; + // const failCount = cells.length - passCount; + + // await prisma.evaluationResult.upsert({ + // where: { + // evaluationId_promptVariantId: { + // evaluationId: evaluation.id, + // promptVariantId: variantId, + // }, + // }, + // create: { + // evaluationId: evaluation.id, + // promptVariantId: variantId, + // passCount, + // failCount, + // }, + // update: { + // passCount, + // failCount, + // }, + // }); + // }), + // ); +}; + +export const runAllEvals = async (experimentId: string) => { + const outputs = await prisma.modelOutput.findMany({ + where: { + scenarioVariantCell: { + promptVariant: { + experimentId, + visible: true, + }, + testScenario: { + visible: true, + }, + }, + }, + include: { + scenarioVariantCell: { + include: { + testScenario: true, + }, + }, + outputEvaluation: true, + }, + }); + const evals = await prisma.evaluation.findMany({ + where: { experimentId }, + }); + + await Promise.all( + outputs.map(async (output) => { + const unrunEvals = evals.filter( + (evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id), + ); + + await Promise.all( + unrunEvals.map(async (evaluation) => { + await saveResult(evaluation, output.scenarioVariantCell.testScenario, output); + }), + ); + }), + ); + + // const cells = await prisma.scenarioVariantCell.findMany({ + // where: { + // promptVariantId: { in: variants.map((v) => v.id) }, + // testScenario: { visible: true }, + // statusCode: { notIn: [429] }, + // }, + // include: { testScenario: true, modelOutput: true }, + // }); + + // await Promise.all( + // variants.map(async (variant) => { + // const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id); + // const passCount = variantCells.filter((cell) => + // runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation), + // ).length; + // const failCount = variantCells.length - passCount; + + // await prisma.evaluationResult.upsert({ + // where: { + // evaluationId_promptVariantId: { + // evaluationId: evaluation.id, + // promptVariantId: variant.id, + // }, + // }, + // create: { + // evaluationId: evaluation.id, + // promptVariantId: variant.id, + // passCount, + // failCount, + // }, + // update: { + // passCount, + // failCount, + // }, + // }); + // }), + // ); }; diff --git a/src/server/utils/evaluateOutput.ts b/src/server/utils/runOneEval.ts similarity index 74% rename from src/server/utils/evaluateOutput.ts rename to src/server/utils/runOneEval.ts index accf9df..619f271 100644 --- a/src/server/utils/evaluateOutput.ts +++ b/src/server/utils/runOneEval.ts @@ -2,30 +2,31 @@ import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/cl import { type ChatCompletion } from "openai/resources/chat"; import { type VariableMap, fillTemplate } from "./fillTemplate"; -export const evaluateOutput = ( - modelOutput: ModelOutput, - scenario: TestScenario, +export const runOneEval = ( evaluation: Evaluation, -): boolean => { + scenario: TestScenario, + modelOutput: ModelOutput, +): number => { const output = modelOutput.output as unknown as ChatCompletion; + const message = output?.choices?.[0]?.message; - if (!message) return false; + if (!message) return 0; const stringifiedMessage = message.content ?? JSON.stringify(message.function_call); const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap); - let match; + let result; switch (evaluation.evalType) { case "CONTAINS": - match = stringifiedMessage.match(matchRegex) !== null; + result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0; break; case "DOES_NOT_CONTAIN": - match = stringifiedMessage.match(matchRegex) === null; + result = stringifiedMessage.match(matchRegex) === null ? 1 : 0; break; } - return match; + return result; };