cache output evals
This commit is contained in:
@@ -41,7 +41,6 @@ model PromptVariant {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
scenarioVariantCells ScenarioVariantCell[]
|
||||
EvaluationResult EvaluationResult[]
|
||||
|
||||
@@index([uiId])
|
||||
}
|
||||
@@ -124,6 +123,7 @@ model ModelOutput {
|
||||
|
||||
scenarioVariantCellId String @db.Uuid
|
||||
scenarioVariantCell ScenarioVariantCell @relation(fields: [scenarioVariantCellId], references: [id], onDelete: Cascade)
|
||||
outputEvaluation OutputEvaluation[]
|
||||
|
||||
@@unique([scenarioVariantCellId])
|
||||
@@index([inputHash])
|
||||
@@ -146,25 +146,26 @@ model Evaluation {
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
EvaluationResult EvaluationResult[]
|
||||
OutputEvaluation OutputEvaluation[]
|
||||
}
|
||||
|
||||
model EvaluationResult {
|
||||
model OutputEvaluation {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
|
||||
passCount Int
|
||||
failCount Int
|
||||
// Number between 0 (fail) and 1 (pass)
|
||||
result Float
|
||||
details String?
|
||||
|
||||
modelOutputId String @db.Uuid
|
||||
modelOutput ModelOutput @relation(fields: [modelOutputId], references: [id], onDelete: Cascade)
|
||||
|
||||
evaluationId String @db.Uuid
|
||||
evaluation Evaluation @relation(fields: [evaluationId], references: [id], onDelete: Cascade)
|
||||
|
||||
promptVariantId String @db.Uuid
|
||||
promptVariant PromptVariant @relation(fields: [promptVariantId], references: [id], onDelete: Cascade)
|
||||
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
@@unique([evaluationId, promptVariantId])
|
||||
@@unique([modelOutputId, evaluationId])
|
||||
}
|
||||
|
||||
// Necessary for Next auth
|
||||
|
||||
@@ -40,7 +40,7 @@ export function EvaluationEditor(props: {
|
||||
<Input
|
||||
size="sm"
|
||||
value={values.label}
|
||||
onChange={(e) => setValues((values) => ({ ...values, name: e.target.value }))}
|
||||
onChange={(e) => setValues((values) => ({ ...values, label: e.target.value }))}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl flex={1}>
|
||||
@@ -125,6 +125,7 @@ export default function EditEvaluations() {
|
||||
}
|
||||
await utils.evaluations.list.invalidate();
|
||||
await utils.promptVariants.stats.invalidate();
|
||||
await utils.scenarioVariantCells.get.invalidate();
|
||||
}, []);
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import { type ModelOutput } from "@prisma/client";
|
||||
import { type SupportedModel } from "~/server/types";
|
||||
import { type Scenario } from "../types";
|
||||
import { useExperiment } from "~/utils/hooks";
|
||||
import { api } from "~/utils/api";
|
||||
import { type RouterOutputs } from "~/utils/api";
|
||||
import { calculateTokenCost } from "~/utils/calculateTokenCost";
|
||||
import { evaluateOutput } from "~/server/utils/evaluateOutput";
|
||||
import { HStack, Icon, Text } from "@chakra-ui/react";
|
||||
import { BsCheck, BsClock, BsCurrencyDollar, BsX } from "react-icons/bs";
|
||||
import { CostTooltip } from "~/components/tooltip/CostTooltip";
|
||||
@@ -15,16 +12,14 @@ const SHOW_TIME = true;
|
||||
export const OutputStats = ({
|
||||
model,
|
||||
modelOutput,
|
||||
scenario,
|
||||
}: {
|
||||
model: SupportedModel | string | null;
|
||||
modelOutput: ModelOutput;
|
||||
modelOutput: NonNullable<
|
||||
NonNullable<RouterOutputs["scenarioVariantCells"]["get"]>["modelOutput"]
|
||||
>;
|
||||
scenario: Scenario;
|
||||
}) => {
|
||||
const timeToComplete = modelOutput.timeToComplete;
|
||||
const experiment = useExperiment();
|
||||
const evals =
|
||||
api.evaluations.list.useQuery({ experimentId: experiment.data?.id ?? "" }).data ?? [];
|
||||
|
||||
const promptTokens = modelOutput.promptTokens;
|
||||
const completionTokens = modelOutput.completionTokens;
|
||||
@@ -38,11 +33,11 @@ export const OutputStats = ({
|
||||
return (
|
||||
<HStack align="center" color="gray.500" fontSize="2xs" mt={{ base: 0, md: 1 }}>
|
||||
<HStack flex={1}>
|
||||
{evals.map((evaluation) => {
|
||||
const passed = evaluateOutput(modelOutput, scenario, evaluation);
|
||||
{modelOutput.outputEvaluation.map((evaluation) => {
|
||||
const passed = evaluation.result > 0.5;
|
||||
return (
|
||||
<HStack spacing={0} key={evaluation.id}>
|
||||
<Text>{evaluation.label}</Text>
|
||||
<Text>{evaluation.evaluation.label}</Text>
|
||||
<Icon
|
||||
as={passed ? BsCheck : BsX}
|
||||
color={passed ? "green.500" : "red.500"}
|
||||
|
||||
@@ -44,10 +44,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
||||
)}
|
||||
<HStack px={cellPadding.x} py={cellPadding.y}>
|
||||
{data.evalResults.map((result) => {
|
||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||
const passedFrac = result.passCount / result.totalCount;
|
||||
return (
|
||||
<HStack key={result.id}>
|
||||
<Text>{result.evaluation.label}</Text>
|
||||
<Text>{result.label}</Text>
|
||||
<Text color={scale(passedFrac).hex()} fontWeight="bold">
|
||||
{(passedFrac * 100).toFixed(1)}%
|
||||
</Text>
|
||||
|
||||
@@ -2,7 +2,7 @@ import { EvalType } from "@prisma/client";
|
||||
import { z } from "zod";
|
||||
import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { reevaluateEvaluation } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
|
||||
export const evaluationsRouter = createTRPCRouter({
|
||||
list: publicProcedure.input(z.object({ experimentId: z.string() })).query(async ({ input }) => {
|
||||
@@ -24,7 +24,7 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
const evaluation = await prisma.evaluation.create({
|
||||
await prisma.evaluation.create({
|
||||
data: {
|
||||
experimentId: input.experimentId,
|
||||
label: input.label,
|
||||
@@ -32,7 +32,10 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
evalType: input.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(evaluation);
|
||||
|
||||
// TODO: this may be a bad UX for slow evals (eg. GPT-4 evals) Maybe need
|
||||
// to kick off a background job or something instead
|
||||
await runAllEvals(input.experimentId);
|
||||
}),
|
||||
|
||||
update: publicProcedure
|
||||
@@ -40,24 +43,30 @@ export const evaluationsRouter = createTRPCRouter({
|
||||
z.object({
|
||||
id: z.string(),
|
||||
updates: z.object({
|
||||
name: z.string().optional(),
|
||||
label: z.string().optional(),
|
||||
value: z.string().optional(),
|
||||
evalType: z.nativeEnum(EvalType).optional(),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
await prisma.evaluation.update({
|
||||
const evaluation = await prisma.evaluation.update({
|
||||
where: { id: input.id },
|
||||
data: {
|
||||
label: input.updates.name,
|
||||
label: input.updates.label,
|
||||
value: input.updates.value,
|
||||
evalType: input.updates.evalType,
|
||||
},
|
||||
});
|
||||
await reevaluateEvaluation(
|
||||
await prisma.evaluation.findUniqueOrThrow({ where: { id: input.id } }),
|
||||
);
|
||||
|
||||
await prisma.outputEvaluation.deleteMany({
|
||||
where: {
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
});
|
||||
// Re-run all evals. Other eval results will already be cached, so this
|
||||
// should only re-run the updated one.
|
||||
await runAllEvals(evaluation.experimentId);
|
||||
}),
|
||||
|
||||
delete: publicProcedure.input(z.object({ id: z.string() })).mutation(async ({ input }) => {
|
||||
|
||||
@@ -32,11 +32,43 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
throw new Error(`Prompt Variant with id ${input.variantId} does not exist`);
|
||||
}
|
||||
|
||||
const evalResults = await prisma.evaluationResult.findMany({
|
||||
where: {
|
||||
promptVariantId: input.variantId,
|
||||
const outputEvals = await prisma.outputEvaluation.groupBy({
|
||||
by: ["evaluationId"],
|
||||
_sum: {
|
||||
result: true,
|
||||
},
|
||||
include: { evaluation: true },
|
||||
_count: {
|
||||
id: true,
|
||||
},
|
||||
where: {
|
||||
modelOutput: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
id: input.variantId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: {
|
||||
experimentId: variant.experimentId,
|
||||
},
|
||||
});
|
||||
|
||||
const evalResults = evals.map((evalItem) => {
|
||||
const evalResult = outputEvals.find((outputEval) => outputEval.evaluationId === evalItem.id);
|
||||
return {
|
||||
id: evalItem.id,
|
||||
label: evalItem.label,
|
||||
passCount: evalResult?._sum?.result ?? 0,
|
||||
totalCount: evalResult?._count?.id ?? 1,
|
||||
};
|
||||
});
|
||||
|
||||
const scenarioCount = await prisma.testScenario.count({
|
||||
@@ -50,7 +82,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
||||
promptVariantId: input.variantId,
|
||||
testScenario: { visible: true },
|
||||
modelOutput: {
|
||||
isNot: null,
|
||||
is: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -21,7 +21,17 @@ export const scenarioVariantCellsRouter = createTRPCRouter({
|
||||
},
|
||||
},
|
||||
include: {
|
||||
modelOutput: true,
|
||||
modelOutput: {
|
||||
include: {
|
||||
outputEvaluation: {
|
||||
include: {
|
||||
evaluation: {
|
||||
select: { label: true },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
|
||||
import { prisma } from "~/server/db";
|
||||
import { autogenerateScenarioValues } from "../autogen";
|
||||
import { recordExperimentUpdated } from "~/server/utils/recordExperimentUpdated";
|
||||
import { reevaluateAll } from "~/server/utils/evaluations";
|
||||
import { runAllEvals } from "~/server/utils/evaluations";
|
||||
import { generateNewCell } from "~/server/utils/generateNewCell";
|
||||
|
||||
export const scenariosRouter = createTRPCRouter({
|
||||
@@ -73,7 +73,7 @@ export const scenariosRouter = createTRPCRouter({
|
||||
});
|
||||
|
||||
// Reevaluate all evaluations now that this scenario is hidden
|
||||
await reevaluateAll(hiddenScenario.experimentId);
|
||||
await runAllEvals(hiddenScenario.experimentId);
|
||||
|
||||
return hiddenScenario;
|
||||
}),
|
||||
|
||||
@@ -6,7 +6,7 @@ import { type JSONSerializable } from "../types";
|
||||
import { sleep } from "../utils/sleep";
|
||||
import { shouldStream } from "../utils/shouldStream";
|
||||
import { generateChannel } from "~/utils/generateChannel";
|
||||
import { reevaluateVariant } from "../utils/evaluations";
|
||||
import { runEvalsForOutput } from "../utils/evaluations";
|
||||
import { constructPrompt } from "../utils/constructPrompt";
|
||||
import { type CompletionCreateParams } from "openai/resources/chat";
|
||||
import { type Prisma } from "@prisma/client";
|
||||
@@ -148,5 +148,7 @@ export const queryLLM = defineTask<queryLLMJob>("queryLLM", async (task) => {
|
||||
},
|
||||
});
|
||||
|
||||
await reevaluateVariant(cell.promptVariantId);
|
||||
if (modelOutput) {
|
||||
await runEvalsForOutput(variant.experimentId, scenario, modelOutput);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,105 +1,154 @@
|
||||
import { type ModelOutput, type Evaluation } from "@prisma/client";
|
||||
import { prisma } from "../db";
|
||||
import { evaluateOutput } from "./evaluateOutput";
|
||||
import { runOneEval } from "./runOneEval";
|
||||
import { type Scenario } from "~/components/OutputsTable/types";
|
||||
|
||||
export const reevaluateVariant = async (variantId: string) => {
|
||||
const variant = await prisma.promptVariant.findUnique({
|
||||
where: { id: variantId },
|
||||
});
|
||||
if (!variant) return;
|
||||
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId: variant.experimentId },
|
||||
});
|
||||
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
const saveResult = async (evaluation: Evaluation, scenario: Scenario, modelOutput: ModelOutput) => {
|
||||
const result = runOneEval(evaluation, scenario, modelOutput);
|
||||
return await prisma.outputEvaluation.upsert({
|
||||
where: {
|
||||
promptVariantId: variantId,
|
||||
retrievalStatus: "COMPLETE",
|
||||
testScenario: { visible: true },
|
||||
modelOutput: { isNot: null },
|
||||
modelOutputId_evaluationId: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
modelOutputId: modelOutput.id,
|
||||
evaluationId: evaluation.id,
|
||||
result,
|
||||
},
|
||||
update: {
|
||||
result,
|
||||
},
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => {
|
||||
const passCount = cells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = cells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variantId,
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
||||
const variants = await prisma.promptVariant.findMany({
|
||||
where: { experimentId: evaluation.experimentId, visible: true },
|
||||
});
|
||||
|
||||
const cells = await prisma.scenarioVariantCell.findMany({
|
||||
where: {
|
||||
promptVariantId: { in: variants.map((v) => v.id) },
|
||||
testScenario: { visible: true },
|
||||
statusCode: { notIn: [429] },
|
||||
modelOutput: { isNot: null },
|
||||
},
|
||||
include: { testScenario: true, modelOutput: true },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
variants.map(async (variant) => {
|
||||
const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||
const passCount = variantCells.filter((cell) =>
|
||||
evaluateOutput(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
).length;
|
||||
const failCount = variantCells.length - passCount;
|
||||
|
||||
await prisma.evaluationResult.upsert({
|
||||
where: {
|
||||
evaluationId_promptVariantId: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
},
|
||||
},
|
||||
create: {
|
||||
evaluationId: evaluation.id,
|
||||
promptVariantId: variant.id,
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
update: {
|
||||
passCount,
|
||||
failCount,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
export const reevaluateAll = async (experimentId: string) => {
|
||||
export const runEvalsForOutput = async (
|
||||
experimentId: string,
|
||||
scenario: Scenario,
|
||||
modelOutput: ModelOutput,
|
||||
) => {
|
||||
const evaluations = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(evaluations.map(reevaluateEvaluation));
|
||||
await Promise.all(
|
||||
evaluations.map(async (evaluation) => await saveResult(evaluation, scenario, modelOutput)),
|
||||
);
|
||||
|
||||
// const cells = await prisma.scenarioVariantCell.findMany({
|
||||
// where: {
|
||||
// promptVariantId: variantId,
|
||||
// retrievalStatus: "COMPLETE",
|
||||
// testScenario: { visible: true },
|
||||
// },
|
||||
// include: { testScenario: true, modelOutput: { include: { OutputEvaluation: true } } },
|
||||
// });
|
||||
|
||||
// await Promise.all(
|
||||
// evaluations.map(async (evaluation) => {
|
||||
// const passCount = cells.filter((cell) =>
|
||||
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
// ).length;
|
||||
// const failCount = cells.length - passCount;
|
||||
|
||||
// await prisma.evaluationResult.upsert({
|
||||
// where: {
|
||||
// evaluationId_promptVariantId: {
|
||||
// evaluationId: evaluation.id,
|
||||
// promptVariantId: variantId,
|
||||
// },
|
||||
// },
|
||||
// create: {
|
||||
// evaluationId: evaluation.id,
|
||||
// promptVariantId: variantId,
|
||||
// passCount,
|
||||
// failCount,
|
||||
// },
|
||||
// update: {
|
||||
// passCount,
|
||||
// failCount,
|
||||
// },
|
||||
// });
|
||||
// }),
|
||||
// );
|
||||
};
|
||||
|
||||
export const runAllEvals = async (experimentId: string) => {
|
||||
const outputs = await prisma.modelOutput.findMany({
|
||||
where: {
|
||||
scenarioVariantCell: {
|
||||
promptVariant: {
|
||||
experimentId,
|
||||
visible: true,
|
||||
},
|
||||
testScenario: {
|
||||
visible: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
include: {
|
||||
scenarioVariantCell: {
|
||||
include: {
|
||||
testScenario: true,
|
||||
},
|
||||
},
|
||||
outputEvaluation: true,
|
||||
},
|
||||
});
|
||||
const evals = await prisma.evaluation.findMany({
|
||||
where: { experimentId },
|
||||
});
|
||||
|
||||
await Promise.all(
|
||||
outputs.map(async (output) => {
|
||||
const unrunEvals = evals.filter(
|
||||
(evaluation) => !output.outputEvaluation.find((e) => e.evaluationId === evaluation.id),
|
||||
);
|
||||
|
||||
await Promise.all(
|
||||
unrunEvals.map(async (evaluation) => {
|
||||
await saveResult(evaluation, output.scenarioVariantCell.testScenario, output);
|
||||
}),
|
||||
);
|
||||
}),
|
||||
);
|
||||
|
||||
// const cells = await prisma.scenarioVariantCell.findMany({
|
||||
// where: {
|
||||
// promptVariantId: { in: variants.map((v) => v.id) },
|
||||
// testScenario: { visible: true },
|
||||
// statusCode: { notIn: [429] },
|
||||
// },
|
||||
// include: { testScenario: true, modelOutput: true },
|
||||
// });
|
||||
|
||||
// await Promise.all(
|
||||
// variants.map(async (variant) => {
|
||||
// const variantCells = cells.filter((cell) => cell.promptVariantId === variant.id);
|
||||
// const passCount = variantCells.filter((cell) =>
|
||||
// runOneEval(cell.modelOutput as ModelOutput, cell.testScenario, evaluation),
|
||||
// ).length;
|
||||
// const failCount = variantCells.length - passCount;
|
||||
|
||||
// await prisma.evaluationResult.upsert({
|
||||
// where: {
|
||||
// evaluationId_promptVariantId: {
|
||||
// evaluationId: evaluation.id,
|
||||
// promptVariantId: variant.id,
|
||||
// },
|
||||
// },
|
||||
// create: {
|
||||
// evaluationId: evaluation.id,
|
||||
// promptVariantId: variant.id,
|
||||
// passCount,
|
||||
// failCount,
|
||||
// },
|
||||
// update: {
|
||||
// passCount,
|
||||
// failCount,
|
||||
// },
|
||||
// });
|
||||
// }),
|
||||
// );
|
||||
};
|
||||
|
||||
@@ -2,30 +2,31 @@ import { type Evaluation, type ModelOutput, type TestScenario } from "@prisma/cl
|
||||
import { type ChatCompletion } from "openai/resources/chat";
|
||||
import { type VariableMap, fillTemplate } from "./fillTemplate";
|
||||
|
||||
export const evaluateOutput = (
|
||||
modelOutput: ModelOutput,
|
||||
scenario: TestScenario,
|
||||
export const runOneEval = (
|
||||
evaluation: Evaluation,
|
||||
): boolean => {
|
||||
scenario: TestScenario,
|
||||
modelOutput: ModelOutput,
|
||||
): number => {
|
||||
const output = modelOutput.output as unknown as ChatCompletion;
|
||||
|
||||
const message = output?.choices?.[0]?.message;
|
||||
|
||||
if (!message) return false;
|
||||
if (!message) return 0;
|
||||
|
||||
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||
|
||||
const matchRegex = fillTemplate(evaluation.value, scenario.variableValues as VariableMap);
|
||||
|
||||
let match;
|
||||
let result;
|
||||
|
||||
switch (evaluation.evalType) {
|
||||
case "CONTAINS":
|
||||
match = stringifiedMessage.match(matchRegex) !== null;
|
||||
result = stringifiedMessage.match(matchRegex) !== null ? 1 : 0;
|
||||
break;
|
||||
case "DOES_NOT_CONTAIN":
|
||||
match = stringifiedMessage.match(matchRegex) === null;
|
||||
result = stringifiedMessage.match(matchRegex) === null ? 1 : 0;
|
||||
break;
|
||||
}
|
||||
|
||||
return match;
|
||||
return result;
|
||||
};
|
||||
Reference in New Issue
Block a user