cache output evals

This commit is contained in:
Kyle Corbitt
2023-07-17 16:52:26 -07:00
parent 1ba18015bc
commit 011b12abb9
11 changed files with 244 additions and 144 deletions

View File

@@ -41,7 +41,6 @@ model PromptVariant {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
scenarioVariantCells ScenarioVariantCell[] scenarioVariantCells ScenarioVariantCell[]
EvaluationResult EvaluationResult[]
@@index([uiId]) @@index([uiId])
} }
@@ -124,6 +123,7 @@ model ModelOutput {
scenarioVariantCellId String @db.Uuid scenarioVariantCellId String @db.Uuid
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade) scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
outputEvaluation OutputEvaluation[]
@@unique([scenarioVariantCellId]) @@unique([scenarioVariantCellId])
@@index([inputHash]) @@index([inputHash])
@@ -146,25 +146,26 @@ model Evaluation {
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
EvaluationResult EvaluationResult[] OutputEvaluation OutputEvaluation[]
} }
model EvaluationResult { model OutputEvaluation {
id String @id @default(uuid()) @db.Uuid id String @id @default(uuid()) @db.Uuid
passCount Int // Number between 0 (fail) and 1 (pass)
failCount Int result Float
details String?
modelOutputId String @db.Uuid
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
evaluationId String @db.Uuid evaluationId String @db.Uuid
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade) evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
promptVariantId String @db.Uuid
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) createdAt DateTime @default(now())
updatedAt DateTime @updatedAt updatedAt DateTime @updatedAt
@@unique([evaluationId, promptVariantId]) @@unique([modelOutputId, evaluationId])
} }
// Necessary for Next auth // Necessary for Next auth

View File

@@ -40,7 +40,7 @@ export function EvaluationEditor(props: {
<Input <Input
size="sm" size="sm"
value={values.label} value={values.label}
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))} onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
/> />
</FormControl> </FormControl>
<FormControl flex={1}> <FormControl flex={1}>
@@ -125,6 +125,7 @@ export default function EditEvaluations() {
} }
await utils.evaluations.list.invalidate(); await utils.evaluations.list.invalidate();
await utils.promptVariants.stats.invalidate(); await utils.promptVariants.stats.invalidate();
await utils.scenarioVariantCells.get.invalidate();
}, []); }, []);
const onCancel = useCallback(() => { const onCancel = useCallback(() => {

View File

@@ -1,10 +1,7 @@
import { type ModelOutput } from "@prisma/client";
import { type SupportedModel } from "~/server/types"; import { type SupportedModel } from "~/server/types";
import { type Scenario } from "../types"; import { type Scenario } from "../types";
import { useExperiment } from "~/utils/hooks"; import { type RouterOutputs } from "~/utils/api";
import { api } from "~/utils/api";
import { calculateTokenCost } from "~/utils/calculateTokenCost"; import { calculateTokenCost } from "~/utils/calculateTokenCost";
import { evaluateOutput } from "~/server/utils/evaluateOutput";
import { HStack, Icon, Text } from "@chakra-ui/react"; import { HStack, Icon, Text } from "@chakra-ui/react";
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs"; import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
import { CostTooltip } from "~/components/tooltip/CostTooltip"; import { CostTooltip } from "~/components/tooltip/CostTooltip";
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
export const OutputStats = ({ export const OutputStats = ({
model, model,
modelOutput, modelOutput,
scenario,
}: { }: {
model: SupportedModel | string | null; model: SupportedModel | string | null;
modelOutput: ModelOutput; modelOutput: NonNullable<
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
>;
scenario: Scenario; scenario: Scenario;
}) => { }) => {
const timeToComplete = modelOutput.timeToComplete; const timeToComplete = modelOutput.timeToComplete;
const experiment = useExperiment();
const evals =
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
const promptTokens = modelOutput.promptTokens; const promptTokens = modelOutput.promptTokens;
const completionTokens = modelOutput.completionTokens; const completionTokens = modelOutput.completionTokens;
@@ -38,11 +33,11 @@ export const OutputStats = ({
return ( return (
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}> <HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
<HStack flex={1}> <HStack flex={1}>
{evals.map((evaluation) => { {modelOutput.outputEvaluation.map((evaluation) => {
const passed = evaluateOutput(modelOutput, scenario, evaluation); const passed = evaluation.result > 0.5;
return ( return (
<HStack spacing={0} key={evaluation.id}> <HStack spacing={0} key={evaluation.id}>
<Text>{evaluation.label}</Text> <Text>{evaluation.evaluation.label}</Text>
<Icon <Icon
as={passed ? BsCheck : BsX} as={passed ? BsCheck : BsX}
color={passed ? "green.500" : "red.500"} color={passed ? "green.500" : "red.500"}

View File

@@ -44,10 +44,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
)} )}
<HStack px={cellPadding.x} py={cellPadding.y}> <HStack px={cellPadding.x} py={cellPadding.y}>
{data.evalResults.map((result) => { {data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount); const passedFrac = result.passCount / result.totalCount;
return ( return (
<HStack key={result.id}> <HStack key={result.id}>
<Text>{result.evaluation.label}</Text> <Text>{result.label}</Text>
<Text color={scale(passedFrac).hex()} fontWeight="bold"> <Text color={scale(passedFrac).hex()} fontWeight="bold">
{(passedFrac * 100).toFixed(1)}% {(passedFrac * 100).toFixed(1)}%
</Text> </Text>

View File

@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
import { z } from "zod"; import { z } from "zod";
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { reevaluateEvaluation } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
export const evaluationsRouter = createTRPCRouter({ export const evaluationsRouter = createTRPCRouter({
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => { list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
@@ -24,7 +24,7 @@ export const evaluationsRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
const evaluation = await prisma.evaluation.create({ await prisma.evaluation.create({
data: { data: {
experimentId: input.experimentId, experimentId: input.experimentId,
label: input.label, label: input.label,
@@ -32,7 +32,10 @@ export const evaluationsRouter = createTRPCRouter({
evalType: input.evalType, evalType: input.evalType,
}, },
}); });
await reevaluateEvaluation(evaluation);
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
// to kick off a background job or something instead
await runAllEvals(input.experimentId);
}), }),
update: publicProcedure update: publicProcedure
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
z.object({ z.object({
id: z.string(), id: z.string(),
updates: z.object({ updates: z.object({
name: z.string().optional(), label: z.string().optional(),
value: z.string().optional(), value: z.string().optional(),
evalType: z.nativeEnum(EvalType).optional(), evalType: z.nativeEnum(EvalType).optional(),
}), }),
}), }),
) )
.mutation(async ({ input }) => { .mutation(async ({ input }) => {
await prisma.evaluation.update({ const evaluation = await prisma.evaluation.update({
where: { id: input.id }, where: { id: input.id },
data: { data: {
label: input.updates.name, label: input.updates.label,
value: input.updates.value, value: input.updates.value,
evalType: input.updates.evalType, evalType: input.updates.evalType,
}, },
}); });
await reevaluateEvaluation(
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }), await prisma.outputEvaluation.deleteMany({
); where: {
evaluationId: evaluation.id,
},
});
// Re-run all evals. Other eval results will already be cached, so this
// should only re-run the updated one.
await runAllEvals(evaluation.experimentId);
}), }),
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => { delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {

View File

@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`); throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
} }
const evalResults = await prisma.evaluationResult.findMany({ const outputEvals = await prisma.outputEvaluation.groupBy({
where: { by: ["evaluationId"],
promptVariantId: input.variantId, _sum: {
result: true,
}, },
include: { evaluation: true }, _count: {
id: true,
},
where: {
modelOutput: {
scenarioVariantCell: {
promptVariant: {
id: input.variantId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
},
});
const evals = await prisma.evaluation.findMany({
where: {
experimentId: variant.experimentId,
},
});
const evalResults = evals.map((evalItem) => {
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
return {
id: evalItem.id,
label: evalItem.label,
passCount: evalResult?._sum?.result ?? 0,
totalCount: evalResult?._count?.id ?? 1,
};
}); });
const scenarioCount = await prisma.testScenario.count({ const scenarioCount = await prisma.testScenario.count({
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenario: { visible: true }, testScenario: { visible: true },
modelOutput: { modelOutput: {
isNot: null, is: {},
}, },
}, },
}); });

View File

@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
}, },
}, },
include: { include: {
modelOutput: true, modelOutput: {
include: {
outputEvaluation: {
include: {
evaluation: {
select: { label: true },
},
},
},
},
},
}, },
}); });
}), }),

View File

@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
import { prisma } from "~/server/db"; import { prisma } from "~/server/db";
import { autogenerateScenarioValues } from "../autogen"; import { autogenerateScenarioValues } from "../autogen";
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated"; import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
import { reevaluateAll } from "~/server/utils/evaluations"; import { runAllEvals } from "~/server/utils/evaluations";
import { generateNewCell } from "~/server/utils/generateNewCell"; import { generateNewCell } from "~/server/utils/generateNewCell";
export const scenariosRouter = createTRPCRouter({ export const scenariosRouter = createTRPCRouter({
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
}); });
// Reevaluate all evaluations now that this scenario is hidden // Reevaluate all evaluations now that this scenario is hidden
await reevaluateAll(hiddenScenario.experimentId); await runAllEvals(hiddenScenario.experimentId);
return hiddenScenario; return hiddenScenario;
}), }),

View File

@@ -6,7 +6,7 @@ import { type JSONSerializable } from "../types";
import { sleep } from "../utils/sleep"; import { sleep } from "../utils/sleep";
import { shouldStream } from "../utils/shouldStream"; import { shouldStream } from "../utils/shouldStream";
import { generateChannel } from "~/utils/generateChannel"; import { generateChannel } from "~/utils/generateChannel";
import { reevaluateVariant } from "../utils/evaluations"; import { runEvalsForOutput } from "../utils/evaluations";
import { constructPrompt } from "../utils/constructPrompt"; import { constructPrompt } from "../utils/constructPrompt";
import { type CompletionCreateParams } from "openai/resources/chat"; import { type CompletionCreateParams } from "openai/resources/chat";
import { type Prisma } from "@prisma/client"; import { type Prisma } from "@prisma/client";
@@ -148,5 +148,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
}, },
}); });
await reevaluateVariant(cell.promptVariantId); if (modelOutput) {
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
}
}); });

View File

@@ -1,105 +1,154 @@
import { type ModelOutput, type Evaluation } from "@prisma/client"; import { type ModelOutput, type Evaluation } from "@prisma/client";
import { prisma } from "../db"; import { prisma } from "../db";
import { evaluateOutput } from "./evaluateOutput"; import { runOneEval } from "./runOneEval";
import { type Scenario } from "~/components/OutputsTable/types";
export const reevaluateVariant = async (variantId: string) => { const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
const variant = await prisma.promptVariant.findUnique({ const result = runOneEval(evaluation, scenario, modelOutput);
where: { id: variantId }, return await prisma.outputEvaluation.upsert({
});
if (!variant) return;
const evaluations = await prisma.evaluation.findMany({
where: { experimentId: variant.experimentId },
});
const cells = await prisma.scenarioVariantCell.findMany({
where: { where: {
promptVariantId: variantId, modelOutputId_evaluationId: {
retrievalStatus: "COMPLETE", modelOutputId: modelOutput.id,
testScenario: { visible: true }, evaluationId: evaluation.id,
modelOutput: { isNot: null }, },
},
create: {
modelOutputId: modelOutput.id,
evaluationId: evaluation.id,
result,
},
update: {
result,
}, },
include: { testScenario: true, modelOutput: true },
}); });
await Promise.all(
evaluations.map(async (evaluation) => {
const passCount = cells.filter((cell) =>
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length;
const failCount = cells.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variantId,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variantId,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
}; };
export const reevaluateEvaluation = async (evaluation: Evaluation) => { export const runEvalsForOutput = async (
const variants = await prisma.promptVariant.findMany({ experimentId: string,
where: { experimentId: evaluation.experimentId, visible: true }, scenario: Scenario,
}); modelOutput: ModelOutput,
) => {
const cells = await prisma.scenarioVariantCell.findMany({
where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
modelOutput: { isNot: null },
},
include: { testScenario: true, modelOutput: true },
});
await Promise.all(
variants.map(async (variant) => {
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
const passCount = variantCells.filter((cell) =>
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
).length;
const failCount = variantCells.length - passCount;
await prisma.evaluationResult.upsert({
where: {
evaluationId_promptVariantId: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
},
},
create: {
evaluationId: evaluation.id,
promptVariantId: variant.id,
passCount,
failCount,
},
update: {
passCount,
failCount,
},
});
}),
);
};
export const reevaluateAll = async (experimentId: string) => {
const evaluations = await prisma.evaluation.findMany({ const evaluations = await prisma.evaluation.findMany({
where: { experimentId }, where: { experimentId },
}); });
await Promise.all(evaluations.map(reevaluateEvaluation)); await Promise.all(
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
);
// const cells = await prisma.scenarioVariantCell.findMany({
// where: {
// promptVariantId: variantId,
// retrievalStatus: "COMPLETE",
// testScenario: { visible: true },
// },
// include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } },
// });
// await Promise.all(
// evaluations.map(async (evaluation) => {
// const passCount = cells.filter((cell) =>
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
// ).length;
// const failCount = cells.length - passCount;
// await prisma.evaluationResult.upsert({
// where: {
// evaluationId_promptVariantId: {
// evaluationId: evaluation.id,
// promptVariantId: variantId,
// },
// },
// create: {
// evaluationId: evaluation.id,
// promptVariantId: variantId,
// passCount,
// failCount,
// },
// update: {
// passCount,
// failCount,
// },
// });
// }),
// );
};
export const runAllEvals = async (experimentId: string) => {
const outputs = await prisma.modelOutput.findMany({
where: {
scenarioVariantCell: {
promptVariant: {
experimentId,
visible: true,
},
testScenario: {
visible: true,
},
},
},
include: {
scenarioVariantCell: {
include: {
testScenario: true,
},
},
outputEvaluation: true,
},
});
const evals = await prisma.evaluation.findMany({
where: { experimentId },
});
await Promise.all(
outputs.map(async (output) => {
const unrunEvals = evals.filter(
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
);
await Promise.all(
unrunEvals.map(async (evaluation) => {
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
}),
);
}),
);
// const cells = await prisma.scenarioVariantCell.findMany({
// where: {
// promptVariantId: { in: variants.map((v) => v.id) },
// testScenario: { visible: true },
// statusCode: { notIn: [429] },
// },
// include: { testScenario: true, modelOutput: true },
// });
// await Promise.all(
// variants.map(async (variant) => {
// const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
// const passCount = variantCells.filter((cell) =>
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
// ).length;
// const failCount = variantCells.length - passCount;
// await prisma.evaluationResult.upsert({
// where: {
// evaluationId_promptVariantId: {
// evaluationId: evaluation.id,
// promptVariantId: variant.id,
// },
// },
// create: {
// evaluationId: evaluation.id,
// promptVariantId: variant.id,
// passCount,
// failCount,
// },
// update: {
// passCount,
// failCount,
// },
// });
// }),
// );
}; };

View File

@@ -2,30 +2,31 @@ import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/cl
import { type ChatCompletion } from "openai/resources/chat"; import { type ChatCompletion } from "openai/resources/chat";
import { type VariableMap, fillTemplate } from "./fillTemplate"; import { type VariableMap, fillTemplate } from "./fillTemplate";
export const evaluateOutput = ( export const runOneEval = (
modelOutput: ModelOutput,
scenario: TestScenario,
evaluation: Evaluation, evaluation: Evaluation,
): boolean => { scenario: TestScenario,
modelOutput: ModelOutput,
): number => {
const output = modelOutput.output as unknown as ChatCompletion; const output = modelOutput.output as unknown as ChatCompletion;
const message = output?.choices?.[0]?.message; const message = output?.choices?.[0]?.message;
if (!message) return false; if (!message) return 0;
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call); const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap); const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
let match; let result;
switch (evaluation.evalType) { switch (evaluation.evalType) {
case "CONTAINS": case "CONTAINS":
match = stringifiedMessage.match(matchRegex) !== null; result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0;
break; break;
case "DOES_NOT_CONTAIN": case "DOES_NOT_CONTAIN":
match = stringifiedMessage.match(matchRegex) === null; result = stringifiedMessage.match(matchRegex) === null ? 1 : 0;
break; break;
} }
return match; return result;
}; };